In [6]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from PIL import Image
import pandas as pd
import os
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.optim.lr_scheduler import CyclicLR
import numpy as np

class OnePunchManDataset(Dataset):
    """Image-classification dataset driven by a CSV annotation file.

    The CSV must contain a ``class`` column holding the label of each row,
    and its second column (position 1) must hold the image path relative to
    ``root_dir``.

    Args:
        csv_file: Path to the annotation CSV.
        root_dir: Directory that the image paths in the CSV are relative to.
        transform: Optional callable applied to the loaded RGB PIL image.

    Attributes:
        annotations: The parsed CSV as a DataFrame.
        class_to_idx: Maps each distinct label to an integer index, numbered
            in order of first appearance in the ``class`` column.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.annotations['class'].unique())}

    def __len__(self):
        """Return the number of annotated rows."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(image, label_index)`` for row ``idx``."""
        img_path = os.path.join(self.root_dir, self.annotations.iloc[idx, 1])
        # Close the underlying file promptly; convert() returns an
        # independent in-memory image that outlives the file handle.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        # Read the label from the same named column class_to_idx was built
        # from, rather than relying on it sitting at a fixed position.
        label = self.class_to_idx[self.annotations['class'].iloc[idx]]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

def mixup_data(x, y, alpha=1.0):
    """Blend each sample of a batch with a randomly chosen partner (mixup).

    Args:
        x: Batch tensor whose first dimension is the batch dimension.
        y: Labels indexed along the same batch dimension as ``x``.
            NOTE(review): assumed to live on the same device as ``x``.
        alpha: Parameter of the symmetric Beta(alpha, alpha) distribution the
            mixing coefficient is drawn from. ``alpha <= 0`` disables mixing
            (coefficient fixed at 1).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x = lam * x + (1 - lam) *
        x[perm]``, ``y_a`` is ``y``, ``y_b`` is ``y[perm]`` and ``lam`` is the
        mixing coefficient.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    batch_size = x.size()[0]
    # Draw the permutation from the CPU generator (as before) and then move
    # it to wherever x lives, instead of unconditionally requiring CUDA.
    index = torch.randperm(batch_size).to(x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Return the ``lam``-weighted combination of the loss against both targets.

    ``criterion`` is called as ``criterion(pred, target)``; the result is
    ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    primary_loss = criterion(pred, y_a)
    partner_loss = criterion(pred, y_b)
    return lam * primary_loss + (1 - lam) * partner_loss

def cutmix_data(x, y, alpha=1.0):
    """Paste a random rectangular patch from a shuffled copy of the batch (CutMix).

    Args:
        x: 4-D batch tensor; the patch is cut along dimensions 2 and 3.
            It is modified IN PLACE and also returned.
        y: Labels indexed along the same batch dimension as ``x``.
            NOTE(review): assumed to live on the same device as ``x``.
        alpha: Parameter of the symmetric Beta(alpha, alpha) distribution the
            initial mixing coefficient is drawn from. ``alpha <= 0`` fixes it
            at 1, which yields an empty patch.

    Returns:
        ``(x, y_a, y_b, lam)`` where ``y_a`` is ``y``, ``y_b`` is ``y[perm]``
        and ``lam`` is the fraction of each image that was NOT replaced,
        recomputed from the actual (clipped) patch area.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    batch_size = x.size()[0]
    # Draw the permutation from the CPU generator (as before) and then move
    # it to wherever x lives, instead of unconditionally requiring CUDA.
    index = torch.randperm(batch_size).to(x.device)

    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    # The right-hand side uses advanced indexing and is therefore a copy, so
    # reading and writing the same tensor here is safe.
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]

    y_a, y_b = y, y[index]
    # The box may have been clipped at the border, so derive the effective
    # mixing coefficient from the area that was really pasted.
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
    return x, y_a, y_b, lam

def rand_bbox(size, lam):
    """Sample a random box inside the extents given by ``size[2]`` and ``size[3]``.

    The box side along each of the two dimensions is the extent scaled by
    ``sqrt(1 - lam)`` (truncated to an int); the centre is drawn uniformly
    and the corners are clipped to ``[0, extent]``.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``: start/end along dimension 2 (``bbx*``)
        and along dimension 3 (``bby*``).
    """
    extent_x, extent_y = size[2], size[3]
    side_ratio = np.sqrt(1. - lam)
    half_x = int(extent_x * side_ratio) // 2
    half_y = int(extent_y * side_ratio) // 2

    # Keep this draw order (dimension 2 first) so seeded runs are unchanged.
    centre_x = np.random.randint(extent_x)
    centre_y = np.random.randint(extent_y)

    return (
        np.clip(centre_x - half_x, 0, extent_x),
        np.clip(centre_y - half_y, 0, extent_y),
        np.clip(centre_x + half_x, 0, extent_x),
        np.clip(centre_y + half_y, 0, extent_y),
    )

# Per-channel mean/std used for normalisation in both pipelines
# (these are the commonly used ImageNet statistics).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize to 224x224, apply random geometric and colour
# augmentation, then convert to a normalised tensor.
_train_steps = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomCrop(224, padding=4),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
train_transform = transforms.Compose(_train_steps)

# Evaluation pipeline: deterministic resize + normalisation only.
_val_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
val_transform = transforms.Compose(_val_steps)

csv_file = 'one-punch-man/train.csv'
root_dir = 'one-punch-man'

# Two dataset objects over the same CSV, one per transform. random_split
# returns Subset views that share ONE underlying dataset, so assigning
# `val_dataset.dataset.transform = val_transform` would also replace the
# training transform and silently switch off all training augmentation.
dataset = OnePunchManDataset(csv_file=csv_file, root_dir=root_dir, transform=train_transform)
val_source = OnePunchManDataset(csv_file=csv_file, root_dir=root_dir, transform=val_transform)

# 80/20 random split; the held-out indices are re-applied to the
# validation-transform dataset so both subsets stay disjoint.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_split = random_split(dataset, [train_size, val_size])
val_dataset = torch.utils.data.Subset(val_source, val_split.indices)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

class ImprovedCNN(nn.Module):
    """EfficientNet-B0 feature extractor followed by a three-layer MLP head.

    The backbone is loaded with the IMAGENET1K_V1 weights and its own
    classifier is replaced by an identity, so it outputs raw features that
    ``fc`` maps to ``num_classes`` logits.

    Args:
        num_classes: Size of the final output layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        pretrained = models.EfficientNet_B0_Weights.IMAGENET1K_V1
        self.efficientnet = models.efficientnet_b0(weights=pretrained)
        # Read the feature width off the stock classifier before dropping it.
        feature_dim = self.efficientnet.classifier[1].in_features
        self.efficientnet.classifier = nn.Identity()
        head_layers = [
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        return self.fc(self.efficientnet(x))

# One output logit per distinct label known to the dataset.
num_classes = len(dataset.class_to_idx)
model = ImprovedCNN(num_classes=num_classes)
model = model.cuda()

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=0.0001,
    weight_decay=1e-4,
)
# Triangular cyclic schedule between 1e-5 and 1e-4; step_size_up is counted
# in scheduler.step() calls.
scheduler = CyclicLR(
    optimizer,
    base_lr=0.00001,
    max_lr=0.0001,
    step_size_up=10,
    mode='triangular',
)

def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, alpha=1.0):
    """Run one training pass over ``train_loader`` and return the mean batch loss.

    Every batch is moved to CUDA and augmented with either mixup or CutMix
    (chosen by a fair coin flip); the loss is the mixing-weighted combination
    of ``criterion`` against both label sets. The optimizer and the scheduler
    are each stepped once per batch.

    Args:
        model: Network to train; switched to training mode.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Learning-rate scheduler stepped after every optimizer step.
        alpha: Beta-distribution parameter forwarded to the augmentation.

    Returns:
        Sum of per-batch losses divided by the number of batches.
    """
    model.train()
    loss_total = 0.0
    for batch_images, batch_labels in train_loader:
        batch_images = batch_images.cuda()
        batch_labels = batch_labels.cuda()

        # Both augmentations share a signature and the same mixed loss.
        augment = mixup_data if np.random.rand() > 0.5 else cutmix_data
        batch_images, targets_a, targets_b, lam = augment(batch_images, batch_labels, alpha)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = mixup_criterion(criterion, logits, targets_a, targets_b, lam)
        batch_loss.backward()
        optimizer.step()
        scheduler.step()
        loss_total += batch_loss.item()
    return loss_total / len(train_loader)

def validate(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode.
        val_loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.

    Returns:
        ``(mean_loss, accuracy)``: the per-batch loss averaged over the
        number of batches, and the accuracy over all samples.
    """
    model.eval()
    loss_total = 0.0
    predicted = []
    expected = []
    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            batch_images = batch_images.cuda()
            batch_labels = batch_labels.cuda()
            logits = model(batch_images)
            loss_total += criterion(logits, batch_labels).item()
            # Index of the largest logit per row is the predicted class.
            batch_preds = torch.max(logits, 1)[1]
            predicted.extend(batch_preds.cpu().numpy())
            expected.extend(batch_labels.cpu().numpy())
    accuracy = accuracy_score(expected, predicted)
    return loss_total / len(val_loader), accuracy

num_epochs = 60
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint. Starting at 0.0 with a strict '>' meant a run whose accuracy
# stayed at 0 never saved, and the load below then failed or silently
# picked up a stale 'best_model.pth' from an earlier run.
best_val_accuracy = float('-inf')
patience = 10  # epochs without improvement tolerated before stopping
counter = 0    # consecutive epochs without improvement

for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, scheduler, alpha=1.0)
    val_loss, val_accuracy = validate(model, val_loader, criterion)
    print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}')

    if val_accuracy > best_val_accuracy:
        # New best: checkpoint the weights and reset the patience counter.
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model.pth')
        counter = 0
    else:
        counter += 1

    if counter >= patience:
        print("Early stopping")
        break

# Restore the best-scoring weights before running inference.
model.load_state_dict(torch.load('best_model.pth'))

def predict(model, img_path, transform):
    """Return the predicted class index for a single image file.

    Args:
        model: Trained network; switched to eval mode.
        img_path: Path of the image to classify.
        transform: Callable turning an RGB PIL image into a tensor without a
            batch dimension (one is added here).

    Returns:
        Index of the largest model output, as a Python int.
    """
    # Run on whatever device the model lives on; a model without parameters
    # falls back to CUDA, which is what this function always used before.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    # Close the image file as soon as the pixel data has been converted.
    with Image.open(img_path) as img:
        image = transform(img.convert("RGB"))
    image = image.unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        output = model(image)
    _, predicted = torch.max(output, 1)
    return predicted.item()

test_folder = 'one-punch-man/test/test'
image_extensions = ('.png', '.jpg', '.jpeg')

# Invert the label mapping once instead of rebuilding a key list per image.
idx_to_class = {idx: cls for cls, idx in dataset.class_to_idx.items()}

predictions = []
# Sorted so the output row order does not depend on the filesystem's
# directory listing order.
for img_name in sorted(os.listdir(test_folder)):
    # Case-insensitive match so files such as 'X.JPG' are not silently skipped.
    if not img_name.lower().endswith(image_extensions):
        continue
    img_path = os.path.join(test_folder, img_name)
    prediction = predict(model, img_path, val_transform)
    predicted_class = idx_to_class[prediction]
    predictions.append({'id': img_name.split('.')[0], 'path': img_name, 'class': predicted_class})

# Explicit columns keep the CSV header intact even when no image was found.
predictions_df = pd.DataFrame(predictions, columns=['id', 'path', 'class'])
predictions_df.to_csv('predictions.csv', index=False)

Epoch 1, Train Loss: 1.7966177940368653, Validation Loss: 1.8067299127578735, Validation Accuracy: 0.16666666666666666
Epoch 2, Train Loss: 1.7927348136901855, Validation Loss: 1.8049814105033875, Validation Accuracy: 0.16666666666666666
Epoch 3, Train Loss: 1.7898041486740113, Validation Loss: 1.8014481663703918, Validation Accuracy: 0.2222222222222222
Epoch 4, Train Loss: 1.7855790853500366, Validation Loss: 1.7997902035713196, Validation Accuracy: 0.25
Epoch 5, Train Loss: 1.7841902017593383, Validation Loss: 1.7968769669532776, Validation Accuracy: 0.25
Epoch 6, Train Loss: 1.7773647546768188, Validation Loss: 1.790298581123352, Validation Accuracy: 0.2777777777777778
Epoch 7, Train Loss: 1.7634865760803222, Validation Loss: 1.7863752245903015, Validation Accuracy: 0.3611111111111111
Epoch 8, Train Loss: 1.7603312253952026, Validation Loss: 1.7824226021766663, Validation Accuracy: 0.3888888888888889
Epoch 9, Train Loss: 1.7761489868164062, Validation Loss: 1.781802237033844, Valida