In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision.models import resnet50
from PIL import Image
import pandas as pd
import numpy as np
from sklearn.metrics import f1_score, accuracy_score
from collections import Counter
import os
import torch.nn.functional as F

root_dir = r'C:\Users\KRIKKER\Course_py\spbu_dl\one-piece\splitted'
train_dir = os.path.join(root_dir, 'train')
test_dir = os.path.join(root_dir, 'test')
print(f"train_dir exists: {os.path.exists(train_dir)}, classes: {len(os.listdir(train_dir)) if os.path.exists(train_dir) else 'N/A'}")
print(f"test_dir exists: {os.path.exists(test_dir)}, sample files: {os.listdir(test_dir)[:5] if os.path.exists(test_dir) else 'N/A'}")

class_to_idx = {
    'Ace': 0, 'Akainu': 1, 'Brook': 2, 'Chopper': 3, 'Crocodile': 4, 'Franky': 5, 'Jinbei': 6, 'Kurohige': 7,
    'Law': 8, 'Luffy': 9, 'Mihawk': 10, 'Nami': 11, 'Rayleigh': 12, 'Robin': 13, 'Sanji': 14, 'Shanks': 15, 'Usopp': 16, 'Zoro': 17
}

class TestDataset(torch.utils.data.Dataset):
    def __init__(self, submission_csv, test_dir, transform=None):
        self.sub_df = pd.read_csv(submission_csv, dtype={'id': str})
        self.test_dir = test_dir
        self.transform = transform
        self.possible_exts = ['.jpg', '.jpeg', '.png']
        print(f"Loaded {len(self.sub_df)} test IDs, sample: {self.sub_df['id'].head(3).tolist()}")

    def __len__(self):
        return len(self.sub_df)

    def __getitem__(self, idx):
        row = self.sub_df.iloc[idx]
        image_id = str(row['id'])
        
        img_path = None
        for ext in self.possible_exts:
            candidate = os.path.join(self.test_dir, image_id + ext)
            if os.path.exists(candidate):
                img_path = candidate
                break
        if img_path is None:
            print(f"Warning: No image for {image_id}. Sample test files: {os.listdir(self.test_dir)[:5]}")
            raise FileNotFoundError(f"Image not found for {image_id}")
        
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        
        return image, image_id

class ResNet50(nn.Module):
    def __init__(self, num_classes=18):
        super(ResNet50, self).__init__()
        self.model = resnet50(pretrained=True)
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        for param in self.model.parameters():
            param.requires_grad = False
        for param in self.model.fc.parameters():
            param.requires_grad = True

    def forward(self, x):
        return self.model(x)

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.75, 1.0)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.15),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.5)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset_full = ImageFolder(train_dir, transform=train_transform)
print(f"Loaded {len(train_dataset_full)} train images across {len(train_dataset_full.classes)} classes: {train_dataset_full.classes}")

expected_classes = sorted(class_to_idx.keys())
if train_dataset_full.classes != expected_classes:
    print("Warning: Class order mismatch! Remapping targets...")
    new_targets = [class_to_idx[train_dataset_full.classes[label]] for label in train_dataset_full.targets]
    train_dataset_full.targets = new_targets
    print("Remapped OK.")

train_labels = train_dataset_full.targets
class_counts = Counter(train_labels)
print("Class distribution:", dict(sorted(class_counts.items())))
total_samples = len(train_labels)
class_weights = {label: total_samples / (len(class_counts) * count) for label, count in class_counts.items()}
weights_tensor = torch.tensor([class_weights.get(i, 1.0) for i in range(18)]).to('cuda' if torch.cuda.is_available() else 'cpu')
print("Class weights:", weights_tensor.tolist()[:5], "...")

train_size = int(0.8 * len(train_dataset_full))
val_size = len(train_dataset_full) - train_size
train_dataset, val_dataset = random_split(train_dataset_full, [train_size, val_size])
val_dataset.dataset.transform = test_transform

train_loader = DataLoader(train_dataset, batch_size=12, shuffle=True, num_workers=0)  # Smaller batch for ResNet50
val_loader = DataLoader(val_dataset, batch_size=12, shuffle=False, num_workers=0)

submission_csv = r'C:\Users\KRIKKER\Course_py\spbu_dl\one-piece\submission.csv'
test_dataset = TestDataset(submission_csv, test_dir, test_transform)
test_loader = DataLoader(test_dataset, batch_size=12, shuffle=False, num_workers=0)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
model = ResNet50(num_classes=18).to(device)

optimizer = optim.AdamW(model.parameters(), lr=0.00005, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)

class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none', weight=weights_tensor)
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if self.alpha is not None:
            focal_loss = self.alpha[targets] * focal_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        return focal_loss.sum()

criterion = FocalLoss(alpha=weights_tensor, gamma=2.0)

def mixup_data(x, y, alpha=0.2):
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

num_epochs = 50
best_f1 = 0.0
patience = 12
counter = 0
label_smoothing = 0.1

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        if np.random.rand() < 0.5:
            images, targets_a, targets_b, lam = mixup_data(images, labels, alpha=0.2)
            outputs = model(images)
            loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
        else:
            outputs = model(images)
            targets = F.one_hot(labels, num_classes=18).float().to(device)
            targets = targets * (1 - label_smoothing) + (label_smoothing / 18.0)
            loss = F.kl_div(F.log_softmax(outputs, dim=1), targets, reduction='batchmean')
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Grad clip
        optimizer.step()
        running_loss += loss.item()
    
    scheduler.step()
    
    if epoch == 15:
        for param in model.model.parameters():
            param.requires_grad = True
        print("Unfroze base layers for full fine-tune.")
        optimizer = optim.AdamW(model.parameters(), lr=0.00001)

    model.eval()
    val_preds, val_labels = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            val_preds.extend(preds.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
    
    val_f1 = f1_score(val_labels, val_preds, average='weighted')
    val_acc = accuracy_score(val_labels, val_preds)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Val F1: {val_f1:.4f}, Val Acc: {val_acc:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}')
    
    if val_f1 > best_f1:
        best_f1 = val_f1
        counter = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        counter += 1
        if counter >= patience:
            print(f'Early stopping at epoch {epoch+1}')
            break

model.load_state_dict(torch.load('best_model.pth', weights_only=True))
model.eval()
predictions = []
ids = []
with torch.no_grad():
    for images, image_ids in test_loader:
        images = images.to(device)
        outputs = model(images)
        _, pred_labels = torch.max(outputs, 1)
        predictions.extend(pred_labels.cpu().numpy())
        ids.extend(image_ids)

sub_df = pd.read_csv(submission_csv, dtype={'id': str})
sub_df['label'] = predictions
sub_df.to_csv('submission.csv', index=False)
print('Updated submission.csv saved!')
print(f"Sample: {sub_df.head()}")
print(f"Pred class dist: {np.bincount(predictions)}")
print(f"Final best Val F1: {best_f1:.4f}, Acc: {val_acc:.4f}")

train_dir exists: True, classes: 18
test_dir exists: True, sample files: ['002999fb-803d-45b2-b647-6203e446e475.png', '00479183-ca4e-49c1-927d-f7209ff23145.png', '0048719e-24d9-4a5a-af75-ca62c2e90477.png', '00c99fe0-e1c9-48ef-8d16-1ea380f190cd.png', '00efbf59-d4c4-4d7e-8bc6-86af36367515.png']
Loaded 2915 train images across 18 classes: ['Ace', 'Akainu', 'Brook', 'Chopper', 'Crocodile', 'Franky', 'Jinbei', 'Kurohige', 'Law', 'Luffy', 'Mihawk', 'Nami', 'Rayleigh', 'Robin', 'Sanji', 'Shanks', 'Usopp', 'Zoro']
Class distribution: {0: 168, 1: 167, 2: 178, 3: 170, 4: 167, 5: 170, 6: 166, 7: 170, 8: 175, 9: 97, 10: 167, 11: 181, 12: 167, 13: 167, 14: 135, 15: 168, 16: 170, 17: 132}
Class weights: [0.9639550447463989, 0.9697272181510925, 0.9098002314567566, 0.9526143670082092, 0.9697272181510925] ...
Loaded 849 test IDs, sample: ['c41628b1-4781-4392-ac8d-6bfe981f73f9', 'f114acb3-fe18-478b-a19a-1f4cbe098851', 'd952ecfe-750c-44b2-96c2-1cac1a4ee146']
Using device: cuda




Epoch [1/50], Loss: 2.4268, Val F1: 0.2298, Val Acc: 0.2419, LR: 0.000050




Epoch [2/50], Loss: 2.1954, Val F1: 0.3785, Val Acc: 0.3859, LR: 0.000050




Epoch [3/50], Loss: 2.0442, Val F1: 0.4871, Val Acc: 0.5077, LR: 0.000050




Epoch [4/50], Loss: 1.9047, Val F1: 0.5382, Val Acc: 0.5489, LR: 0.000049




Epoch [5/50], Loss: 1.7719, Val F1: 0.6286, Val Acc: 0.6346, LR: 0.000049




Epoch [6/50], Loss: 1.6526, Val F1: 0.6101, Val Acc: 0.6158, LR: 0.000048




Epoch [7/50], Loss: 1.5695, Val F1: 0.6813, Val Acc: 0.6861, LR: 0.000048




Epoch [8/50], Loss: 1.4617, Val F1: 0.7088, Val Acc: 0.7101, LR: 0.000047




Epoch [9/50], Loss: 1.4303, Val F1: 0.6938, Val Acc: 0.6964, LR: 0.000046




Epoch [10/50], Loss: 1.2982, Val F1: 0.7139, Val Acc: 0.7153, LR: 0.000045




Epoch [11/50], Loss: 1.2869, Val F1: 0.7277, Val Acc: 0.7290, LR: 0.000044




Epoch [12/50], Loss: 1.2445, Val F1: 0.7388, Val Acc: 0.7376, LR: 0.000043




Epoch [13/50], Loss: 1.1918, Val F1: 0.7357, Val Acc: 0.7427, LR: 0.000042




Epoch [14/50], Loss: 1.0974, Val F1: 0.7500, Val Acc: 0.7530, LR: 0.000041




Epoch [15/50], Loss: 1.1070, Val F1: 0.7371, Val Acc: 0.7427, LR: 0.000040




Unfroze base layers for full fine-tune.
Epoch [16/50], Loss: 1.0690, Val F1: 0.7547, Val Acc: 0.7581, LR: 0.000010




Epoch [17/50], Loss: 0.8240, Val F1: 0.8730, Val Acc: 0.8748, LR: 0.000010




Epoch [18/50], Loss: 0.4885, Val F1: 0.9147, Val Acc: 0.9142, LR: 0.000010




Epoch [19/50], Loss: 0.3499, Val F1: 0.9212, Val Acc: 0.9211, LR: 0.000010




Epoch [20/50], Loss: 0.3259, Val F1: 0.9331, Val Acc: 0.9331, LR: 0.000010




Epoch [21/50], Loss: 0.2974, Val F1: 0.9330, Val Acc: 0.9331, LR: 0.000010




Epoch [22/50], Loss: 0.2734, Val F1: 0.9292, Val Acc: 0.9297, LR: 0.000010




Epoch [23/50], Loss: 0.2589, Val F1: 0.9378, Val Acc: 0.9383, LR: 0.000010




Epoch [24/50], Loss: 0.2550, Val F1: 0.9384, Val Acc: 0.9383, LR: 0.000010




Epoch [25/50], Loss: 0.2295, Val F1: 0.9414, Val Acc: 0.9417, LR: 0.000010




Epoch [26/50], Loss: 0.2430, Val F1: 0.9484, Val Acc: 0.9485, LR: 0.000010




Epoch [27/50], Loss: 0.2336, Val F1: 0.9480, Val Acc: 0.9485, LR: 0.000010




Epoch [28/50], Loss: 0.1827, Val F1: 0.9496, Val Acc: 0.9503, LR: 0.000010




Epoch [29/50], Loss: 0.2077, Val F1: 0.9463, Val Acc: 0.9468, LR: 0.000010




Epoch [30/50], Loss: 0.2005, Val F1: 0.9587, Val Acc: 0.9588, LR: 0.000010




Epoch [31/50], Loss: 0.1598, Val F1: 0.9551, Val Acc: 0.9554, LR: 0.000010




Epoch [32/50], Loss: 0.2445, Val F1: 0.9408, Val Acc: 0.9417, LR: 0.000010




Epoch [33/50], Loss: 0.1973, Val F1: 0.9520, Val Acc: 0.9520, LR: 0.000010




Epoch [34/50], Loss: 0.2143, Val F1: 0.9604, Val Acc: 0.9605, LR: 0.000010




Epoch [35/50], Loss: 0.2020, Val F1: 0.9535, Val Acc: 0.9537, LR: 0.000010




Epoch [36/50], Loss: 0.2440, Val F1: 0.9605, Val Acc: 0.9605, LR: 0.000010




Epoch [37/50], Loss: 0.2011, Val F1: 0.9480, Val Acc: 0.9485, LR: 0.000010




Epoch [38/50], Loss: 0.1802, Val F1: 0.9518, Val Acc: 0.9520, LR: 0.000010




Epoch [39/50], Loss: 0.1928, Val F1: 0.9639, Val Acc: 0.9640, LR: 0.000010




Epoch [40/50], Loss: 0.2216, Val F1: 0.9566, Val Acc: 0.9571, LR: 0.000010




Epoch [41/50], Loss: 0.1561, Val F1: 0.9552, Val Acc: 0.9554, LR: 0.000010




Epoch [42/50], Loss: 0.1766, Val F1: 0.9553, Val Acc: 0.9554, LR: 0.000010




Epoch [43/50], Loss: 0.1949, Val F1: 0.9600, Val Acc: 0.9605, LR: 0.000010




Epoch [44/50], Loss: 0.1917, Val F1: 0.9621, Val Acc: 0.9623, LR: 0.000010




Epoch [45/50], Loss: 0.1985, Val F1: 0.9607, Val Acc: 0.9605, LR: 0.000010




Epoch [46/50], Loss: 0.1811, Val F1: 0.9590, Val Acc: 0.9588, LR: 0.000010




Epoch [47/50], Loss: 0.1334, Val F1: 0.9692, Val Acc: 0.9691, LR: 0.000010




Epoch [48/50], Loss: 0.1630, Val F1: 0.9605, Val Acc: 0.9605, LR: 0.000010




Epoch [49/50], Loss: 0.1803, Val F1: 0.9602, Val Acc: 0.9605, LR: 0.000010




Epoch [50/50], Loss: 0.1963, Val F1: 0.9657, Val Acc: 0.9657, LR: 0.000010




Updated submission.csv saved!
Sample:                                      id  label
0  c41628b1-4781-4392-ac8d-6bfe981f73f9     10
1  f114acb3-fe18-478b-a19a-1f4cbe098851      7
2  d952ecfe-750c-44b2-96c2-1cac1a4ee146      2
3  2c14ec77-44ca-4b3c-b470-96286411c617     14
4  712c3ce9-750a-4cc4-8f94-f8033c31cb2c      0
Pred class dist: [53 50 45 48 50 51 50 46 41 32 55 50 53 50 43 48 42 42]
Final best Val F1: 0.9692, Acc: 0.9657
