In [1]:

from google.colab import drive
import os

# All relative paths used later (train.pkl, test.pkl, the checkpoint and the
# submission CSV) are resolved against this Google Drive folder.
DRIVE_MOUNT_POINT = '/content/drive'
project_path = '/content/drive/MyDrive/Cornell/pvz'

# Mount Drive only when it is not already visible. A fresh runtime needs the mount
# for os.chdir below to succeed, while re-running the cell skips re-mounting.
if not os.path.isdir(os.path.join(DRIVE_MOUNT_POINT, 'MyDrive')):
    drive.mount(DRIVE_MOUNT_POINT)

os.chdir(project_path)
# The label is Chinese for "current working path:".
print("当前工作路径：", os.getcwd())


当前工作路径： /content/drive/MyDrive/Cornell/pvz


In [2]:
import copy
import os
import pickle
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms.functional as TF
from PIL import Image
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from torchvision import transforms, models

# Set random seed for reproducibility
# torch.manual_seed(42)
# np.random.seed(42)
# if torch.cuda.is_available():
#     torch.cuda.manual_seed_all(42)

## Configuration

In [3]:
# ---- File locations (relative to the working directory set in the first cell) ----
TRAIN_PKL_PATH, TEST_PKL_PATH = 'train.pkl', 'test.pkl'
MODEL_SAVE_PATH = 'best_siamese_resnet_acc.pth'            # checkpoint with the best validation accuracy
SUBMISSION_CSV_PATH = 'submission_siamese_resnet_acc.csv'  # written by the inference cell

# ---- Optimisation ----
BATCH_SIZE = 64
LEARNING_RATE = 0.0001      # AdamW learning rate
WEIGHT_DECAY = 0.0001       # AdamW weight decay
EPOCHS = 80                 # upper bound on epochs; early stopping may end training sooner
PATIENCE_LR = 5             # ReduceLROnPlateau patience (epochs without val-accuracy gain)
PATIENCE_ES = 10            # early-stopping patience (epochs without val-accuracy gain)
VALIDATION_SPLIT = 0.2      # fraction of train.pkl held out for validation
USE_PRETRAINED_BASE = True  # initialise the ResNet-18 backbone from ImageNet weights

# ---- Inference ----
INFERENCE_BATCH_SIZE = 2 * BATCH_SIZE  # no gradients are kept at inference, so larger batches fit

# ---- Device ----
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {DEVICE}")

Using device: cuda


## Data Handling

In [4]:
# Training-time augmentation for the 24x24 single-channel images. The dataset calls
# this separately for img1 and img2, so each image of a pair gets its own random draw.
augment_transform = transforms.Compose([
    transforms.RandomRotation(degrees=10),          # rotate by a random angle in [-10, 10] degrees
    transforms.RandomHorizontalFlip(p=0.5),         # mirror half of the samples
    transforms.RandomCrop(size=24, padding=2),      # pad to 28x28, crop back to 24x24: random shift of up to 2 px
    transforms.ToTensor(),                          # 8-bit grayscale PIL image -> float tensor (1, H, W) in [0, 1]
    transforms.Normalize(mean=(0.5,), std=(0.5,)),  # [0, 1] -> [-1, 1]
])

# Dataset
class RPSSiameseDataset(Dataset):
    """Labelled image pairs for the Siamese "does img1 beat img2?" classifier.

    ``pkl_path`` must unpickle to an object indexable by ``'img1'``, ``'img2'`` and
    ``'label'``. Each image must be a 24x24 array (checked below). The target is 1
    when the raw label equals 1 (img1 beats img2) and 0 for any other label.

    Loading is best-effort: any error is printed and leaves the dataset empty
    (``len(ds) == 0``) instead of raising, so callers must check the length.
    NOTE(security): ``pickle.load`` can execute arbitrary code, so only load trusted files.

    Args:
        pkl_path: Path to the pickle file.
        transform: Callable applied separately to each PIL image of a pair, so random
            augmentations differ between img1 and img2. If ``None``, a deterministic
            ``ToTensor`` + ``Normalize((0.5,), (0.5,))`` is applied. This matches the
            preprocessing of ``RPSInferenceDataset``, so un-augmented samples share the
            [-1, 1] input range the model is trained and evaluated on.
    """

    def __init__(self, pkl_path, transform=None):
        self.imgs1 = None
        self.imgs2 = None
        self.labels = None
        self.transform = transform

        try:
            with open(pkl_path, 'rb') as f:
                data = pickle.load(f)
            print(f"Pickle file '{pkl_path}' loaded successfully.")

            # Cast to uint8 so Image.fromarray in __getitem__ builds 8-bit grayscale images.
            print("Attempting to stack 'img1' data...")
            self.imgs1 = np.stack(data['img1']).astype(np.uint8)
            print(f"  'img1' stacked successfully. Shape: {self.imgs1.shape}")

            print("Attempting to stack 'img2' data...")
            self.imgs2 = np.stack(data['img2']).astype(np.uint8)
            print(f"  'img2' stacked successfully. Shape: {self.imgs2.shape}")

            labels_raw = np.array(data['label'])
            # Binary target: 1 if img1 beats img2 (raw label == 1), else 0.
            self.labels = torch.tensor((labels_raw == 1).astype(np.int64))

            assert len(self.imgs1) == len(self.labels), "Mismatch between img1 count and labels count."
            assert len(self.imgs2) == len(self.labels), "Mismatch between img2 count and labels count."
            assert self.imgs1.shape[1:] == (24, 24), f"img1 shape error: {self.imgs1.shape}"
            assert self.imgs2.shape[1:] == (24, 24), f"img2 shape error: {self.imgs2.shape}"

            print(f"Dataset initialized successfully from {pkl_path}: {len(self.labels)} samples.")

        except FileNotFoundError:
            print(f"Error: File not found at {pkl_path}")
        except Exception as e:
            print(f"Error during dataset initialization from {pkl_path}: {e}")
            # Discard partially loaded arrays so the dataset is consistently empty.
            self.imgs1, self.imgs2, self.labels = None, None, None

    def __len__(self):
        return len(self.labels) if self.labels is not None else 0

    def __getitem__(self, idx):
        """Return ``(im1, im2, y)``, where ``y`` is the 0/1 target as a 0-d int64 tensor."""
        if self.imgs1 is None or self.imgs2 is None:
            raise IndexError("Dataset not initialized correctly.")

        transform = self.transform
        if transform is None:
            # Deterministic fallback, identical to the inference preprocessing, so that
            # un-augmented samples are normalized to [-1, 1] like all other model inputs.
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,)),
            ])

        # Separate calls -> independent random augmentation for the two images.
        im1 = transform(Image.fromarray(self.imgs1[idx]))
        im2 = transform(Image.fromarray(self.imgs2[idx]))
        return im1, im2, self.labels[idx]

# Dataset for Inference (Corrected Loading)
class RPSInferenceDataset(Dataset):
    """Unlabelled test pairs, yielded as ``(im1, im2, id)``.

    Images only go through ``ToTensor`` + ``Normalize((0.5,), (0.5,))`` (no augmentation).
    Loading is best-effort: failures are printed and leave the dataset empty
    (``len == 0``). NOTE(security): uses ``pickle.load``, so only load trusted files.

    Args:
        pkl_path: Path to the pickle; the loaded object must support ``.get(key)``.
        ids_key, img1_key, img2_key: Keys of the id column and of the two image columns.
    """

    def __init__(self, pkl_path, ids_key='id', img1_key='img1', img2_key='img2'):
        self.ids, self.imgs1, self.imgs2 = None, None, None
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.5,), (0.5,))
        self.transform = transforms.Compose([to_tensor, normalize])

        try:
            with open(pkl_path, 'rb') as fh:
                data = pickle.load(fh)
            print(f"Pickle file '{pkl_path}' loaded successfully for inference.")

            raw_ids = data.get(ids_key)
            if raw_ids is None:
                raise ValueError(f"Key '{ids_key}' not found in pickle file.")
            self.ids = np.array(raw_ids)

            self.imgs1 = self._stack_images(data, 'img1', img1_key)
            self.imgs2 = self._stack_images(data, 'img2', img2_key)

            # Sanity checks: every count check runs before any shape check.
            n_ids = len(self.ids)
            pairs = (('img1', self.imgs1), ('img2', self.imgs2))
            for tag, arr in pairs:
                assert len(arr) == n_ids, f"Mismatch between {tag} count and ID count."
            for tag, arr in pairs:
                assert arr.shape[1:] == (24, 24), f"{tag} shape error: {arr.shape}"

            print(f"Inference dataset initialized successfully from {pkl_path}: {n_ids} samples.")

        except FileNotFoundError:
            print(f"Error: File not found at {pkl_path}")
        except Exception as e:
            print(f"Error during inference dataset initialization from {pkl_path}: {e}")
            self.ids, self.imgs1, self.imgs2 = None, None, None

    @staticmethod
    def _stack_images(data, tag, key):
        """Stack ``data.get(key)`` into a uint8 array, logging progress under ``tag``."""
        print(f"Attempting to stack '{tag}' data for inference...")
        images = data.get(key)
        if images is None:
            raise ValueError(f"Key '{key}' not found.")
        stacked = np.stack(images).astype(np.uint8)
        print(f"  '{tag}' stacked successfully. Shape: {stacked.shape}")
        return stacked

    def __len__(self):
        return 0 if self.ids is None else len(self.ids)

    def __getitem__(self, idx):
        if any(part is None for part in (self.imgs1, self.imgs2, self.ids)):
            raise IndexError("Inference dataset not initialized correctly.")

        # Deterministic preprocessing only (ToTensor + Normalize).
        im1, im2 = (self.transform(Image.fromarray(arr[idx])) for arr in (self.imgs1, self.imgs2))
        return im1, im2, self.ids[idx]

print("Corrected Dataset classes defined.")

Corrected Dataset classes defined.


## Model Definition

In [5]:
# Base Network (resnet18 for Feature Extractor)
def get_base_resnet18(pretrained=True):
    """Build a ResNet-18 feature extractor that accepts single-channel images.

    Args:
        pretrained: Load ImageNet weights (``ResNet18_Weights.DEFAULT``) when True.

    Returns:
        ``(backbone, feature_dim)``. ``backbone`` has its classification layer replaced
        by ``nn.Identity``, so it outputs ``feature_dim``-dimensional embeddings.
    """
    backbone = models.resnet18(weights=models.ResNet18_Weights.DEFAULT if pretrained else None)

    # Replace the 3-channel stem with a 1-channel conv of identical geometry.
    rgb_stem = backbone.conv1
    backbone.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    if pretrained and rgb_stem.weight.shape[1] == 3:
        # Collapse the pretrained RGB filters to grayscale by averaging over input channels.
        backbone.conv1.weight.data = rgb_stem.weight.data.mean(dim=1, keepdim=True)

    feature_dim = backbone.fc.in_features
    backbone.fc = nn.Identity()
    return backbone, feature_dim

# Siamese Network
class SiameseNet(nn.Module):
    """Weight-sharing Siamese classifier for ordered image pairs.

    Both images are encoded by the same ResNet-18 backbone. The two embeddings are
    concatenated in (input1, input2) order, and an MLP head maps them to 2 logits
    (class 1 = "img1 beats img2", matching the dataset targets).
    """

    def __init__(self, pretrained_base=True):
        super().__init__()
        backbone, feat_dim = get_base_resnet18(pretrained=pretrained_base)
        self.base_network = backbone
        # Layer order fixes the state_dict keys (classifier_head.0 ... .4) used by saved checkpoints.
        head_layers = [
            nn.Linear(2 * feat_dim, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
            nn.Linear(256, 2),  # 2 logits for CrossEntropyLoss
        ]
        self.classifier_head = nn.Sequential(*head_layers)

    def forward(self, input1, input2):
        # The two inputs are encoded in separate backbone calls, so in train mode the
        # BatchNorm batch statistics are computed over each input batch on its own.
        embeddings = [self.base_network(x) for x in (input1, input2)]
        return self.classifier_head(torch.cat(embeddings, dim=1))

print("Model classes defined.")

Model classes defined.


## Training

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Switches ``model`` to train mode and takes one optimizer step per ``(im1, im2, y)``
    batch. It also prints the wall-clock duration of the pass.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen. Both are 0 when the loader is empty.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0
    t0 = time.time()

    for im1, im2, y in loader:
        im1 = im1.to(device)
        im2 = im2.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        logits = model(im1, im2)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        # criterion returns a per-batch mean; re-weight by batch size for a per-sample average.
        loss_sum += loss.item() * im1.size(0)
        n_correct += (logits.argmax(dim=1) == y).sum().item()
        n_seen += y.size(0)

    epoch_time = time.time() - t0
    print(f"  Train Time: {epoch_time:.2f}s")
    if n_seen == 0:
        return 0, 0
    return loss_sum / n_seen, n_correct / n_seen

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` in eval mode without gradient tracking.

    Returns:
        ``(mean_loss, accuracy)`` over all samples. Both are 0 when the loader is empty.
    """
    model.eval()
    loss_sum, hits, seen = 0, 0, 0
    with torch.no_grad():
        for batch in loader:
            im1, im2, y = (t.to(device) for t in batch)
            logits = model(im1, im2)
            # Re-weight the batch-mean loss by batch size to get a per-sample average.
            loss_sum += criterion(logits, y).item() * im1.size(0)
            hits += (logits.argmax(dim=1) == y).sum().item()
            seen += y.size(0)

    if not seen:
        return 0, 0
    return loss_sum / seen, hits / seen

print("Training/validation functions defined.")

print("\n Starting Training Phase")

# DataLoader settings shared by both loaders. os.cpu_count() may return None, in which
# case data is loaded in the main process. Pinned host memory only helps GPU transfers.
loader_workers = (os.cpu_count() or 0) // 2
pin_host_memory = DEVICE.type == 'cuda'

full_dataset = RPSSiameseDataset(TRAIN_PKL_PATH, transform=augment_transform)
if len(full_dataset) == 0:
    print("Training aborted: Could not load training data.")
else:
    n_total = len(full_dataset)
    n_val = int(n_total * VALIDATION_SPLIT)
    n_train = n_total - n_val

    if n_val == 0 and n_total > 0: # Ensure validation set is not empty if possible
        n_train = max(1, n_total - 1)
        n_val = n_total - n_train
        print(f"Warning: Validation split resulted in 0 samples. Using {n_val} sample for validation.")

    print(f"Splitting data: {n_train} train, {n_val} validation")
    try:
        train_dataset, val_dataset = random_split(full_dataset, [n_train, n_val])
    except ValueError as e:
        print(f"Error during random_split: {e}. Ensure dataset has samples and split is valid.")
        train_dataset, val_dataset = None, None # Abort training if split fails

    if train_dataset is not None and val_dataset is not None and len(val_dataset) > 0:
        # random_split returns Subsets of full_dataset, whose transform is the random
        # augmentation. Point the validation indices at a shallow copy (it shares the
        # image arrays) that only applies ToTensor + Normalize, the same preprocessing
        # as the test set. Validation accuracy drives LR scheduling, early stopping and
        # checkpoint selection, so it must be measured on clean, deterministic inputs.
        val_source = copy.copy(full_dataset)
        val_source.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        val_dataset = Subset(val_source, val_dataset.indices)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                  num_workers=loader_workers, pin_memory=pin_host_memory)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                num_workers=loader_workers, pin_memory=pin_host_memory)
        print(f"DataLoaders created: Train batches={len(train_loader)}, Val batches={len(val_loader)}")

        # Initialize Model, Loss, Optimizer, Scheduler
        model = SiameseNet(pretrained_base=USE_PRETRAINED_BASE).to(DEVICE)
        print(f"Model: SiameseNet with {'pretrained' if USE_PRETRAINED_BASE else 'random'} base ResNet-18 loaded to {DEVICE}.")

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        # mode='max' because the monitored metric is validation accuracy. The deprecated
        # `verbose` flag is omitted; the current LR is printed every epoch instead.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=PATIENCE_LR
        )

        # Training Loop
        best_val_acc = 0.0
        epochs_no_improve = 0
        checkpoint_saved = False
        training_start_time = time.time()

        for epoch in range(1, EPOCHS + 1):
            print(f"\nEpoch {epoch}/{EPOCHS}")
            tr_loss, tr_acc = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)

            # val_loader is non-empty here: the guard above requires len(val_dataset) > 0.
            va_loss, va_acc = validate(model, val_loader, criterion, DEVICE)
            current_lr = optimizer.param_groups[0]['lr']
            print(f"  Epoch {epoch:2d} Summary | Train Loss: {tr_loss:.4f}, Acc: {tr_acc:.4f} | Val Loss: {va_loss:.4f}, Acc: {va_acc:.4f} | LR: {current_lr:.1e}")

            scheduler.step(va_acc) # Scheduler steps based on validation accuracy

            if va_acc > best_val_acc:
                print(f"  🚀 Validation accuracy improved from {best_val_acc:.4f} to {va_acc:.4f}. Saving model...")
                best_val_acc = va_acc
                torch.save(model.state_dict(), MODEL_SAVE_PATH)
                checkpoint_saved = True
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1
                print(f"  ⏳ Validation accuracy did not improve for {epochs_no_improve}/{PATIENCE_ES} epochs.")
                if epochs_no_improve >= PATIENCE_ES:
                    print(f"  🚨 Early stopping triggered after epoch {epoch}. Best validation accuracy: {best_val_acc:.4f}")
                    break

        training_duration = time.time() - training_start_time
        print(f"\n--- Training Finished ---")
        print(f"Total Training Time: {training_duration:.2f}s")
        print(f"Best Validation Accuracy: {best_val_acc:.4f}")
        if checkpoint_saved:
            print(f"Best model saved to: {MODEL_SAVE_PATH}")
        else:
            print(f"No checkpoint was saved in this run; any existing {MODEL_SAVE_PATH} comes from an earlier run.")
    else:
        print("Training aborted: could not create a non-empty train/validation split.")


print("\n--- Starting Evaluation Phase (on Validation Set) ---")

# Start from "no model" so that later cells can never pick up, through `eval_model`,
# a model left over from an earlier run of this cell or one whose checkpoint failed to load.
eval_model = None

# Load Best Model
if os.path.exists(MODEL_SAVE_PATH):
    try:
        loaded_model = SiameseNet(pretrained_base=USE_PRETRAINED_BASE).to(DEVICE) # Re-create model structure
        loaded_model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
        loaded_model.eval() # Set to evaluation mode
        eval_model = loaded_model  # published only after the weights loaded successfully
        print(f"Loaded best model state from {MODEL_SAVE_PATH}")
    except Exception as e:
        print(f"Error loading model from {MODEL_SAVE_PATH}: {e}")
else:
    print(f"Skipping evaluation: Model file not found at {MODEL_SAVE_PATH}")

if eval_model is not None:
    val_subset = globals().get('val_dataset')
    if val_subset is not None and len(val_subset) > 0:
        try:
            # A Subset produced by random_split wraps the *augmenting* training dataset.
            # Evaluate the same indices through a shallow copy that only applies
            # ToTensor + Normalize, so this evaluation really runs without augmentation.
            if isinstance(val_subset, Subset) and hasattr(val_subset.dataset, 'transform'):
                clean_source = copy.copy(val_subset.dataset)
                clean_source.transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,)),
                ])
                val_subset = Subset(clean_source, val_subset.indices)

            val_eval_loader = DataLoader(
                val_subset, batch_size=INFERENCE_BATCH_SIZE, shuffle=False,
                num_workers=(os.cpu_count() or 0) // 2,  # os.cpu_count() may return None
                pin_memory=DEVICE.type == 'cuda',
            )

            # validate() also reports the loss, so it needs a criterion.
            eval_criterion = nn.CrossEntropyLoss()
            val_eval_loss, val_eval_acc = validate(eval_model, val_eval_loader, eval_criterion, DEVICE)
            print(f"Evaluation Accuracy on Validation Set (Best Model): {val_eval_acc:.4f}")
            print(f"Evaluation Loss on Validation Set (Best Model):   {val_eval_loss:.4f}")
        except Exception as e:
            print(f"Error evaluating model: {e}")
    else:
        print("Skipping evaluation on validation set (validation set not available or empty).")


## Inference

In [None]:
print("\n Starting Inference Phase")

# Always load the weights from the checkpoint file. This makes the cell independent of
# whether the evaluation cell ran or succeeded, and it never predicts with a model
# whose weights failed to load.
model_for_inference = None
if os.path.exists(MODEL_SAVE_PATH):
    try:
        candidate_model = SiameseNet(pretrained_base=USE_PRETRAINED_BASE).to(DEVICE)
        candidate_model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
        model_for_inference = candidate_model
        print(f"Loaded best model state from {MODEL_SAVE_PATH} for inference.")
    except Exception as e:
        print(f"Error loading model for inference: {e}")
else:
    print(f"Cannot perform inference: Model file not found at {MODEL_SAVE_PATH}")

#  Load Test Data
if model_for_inference is not None:
    model_for_inference.eval() # Ensure model is in eval mode

    test_dataset = RPSInferenceDataset(TEST_PKL_PATH)
    if len(test_dataset) > 0:
        test_loader = DataLoader(
            test_dataset, batch_size=INFERENCE_BATCH_SIZE, shuffle=False,
            num_workers=(os.cpu_count() or 0) // 2,  # os.cpu_count() may return None
            pin_memory=DEVICE.type == 'cuda',
        )
        print(f"Test DataLoader created: {len(test_loader)} batches.")

        #  Perform Inference
        all_preds = []
        all_ids = []
        inference_start_time = time.time()
        with torch.no_grad():
            for im1, im2, ids_batch in test_loader:
                logits = model_for_inference(im1.to(DEVICE), im2.to(DEVICE))
                all_preds.append(logits.argmax(dim=1).cpu().numpy())
                # Depending on the id dtype, the DataLoader yields a tensor or a plain list.
                all_ids.extend(ids_batch.tolist() if isinstance(ids_batch, torch.Tensor) else ids_batch)

        inference_duration = time.time() - inference_start_time
        print(f"Inference completed in {inference_duration:.2f}s")

        if all_preds:
            final_preds = np.concatenate(all_preds)
        else:
            final_preds = np.array([])
            print("Warning: No predictions were generated.")

        # Post-processing & Submission
        print("\n--- Starting Post-processing & Submission ---")

        if len(final_preds) == len(all_ids):
            # Map predictions (0/1) back to labels (-1/+1)
            final_labels = np.where(final_preds == 1, 1, -1)

            submission_df = pd.DataFrame({'id': all_ids, 'label': final_labels})

            try:
                submission_df.to_csv(SUBMISSION_CSV_PATH, index=False)
                print(f"✅ Submission file saved successfully to: {SUBMISSION_CSV_PATH}")
                print("\nSubmission file preview:")
                print(submission_df.head())
            except Exception as e:
                print(f"Error saving submission file: {e}")
        else:
            print(f"Error: Number of predictions ({len(final_preds)}) does not match number of IDs ({len(all_ids)}). Cannot create submission file.")

    else:
        print("Skipping inference: Test dataset could not be loaded or is empty.")
else:
    print("Skipping inference: Model not loaded.")

print("\n--- Notebook Execution Finished ---")