In [None]:
import os
import glob
import csv
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

# Select the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Dataset root and the directory that receives training checkpoints.
DATASET_PATH = "/kaggle/input/hw1-data/data"
CKPT_PATH = "/kaggle/working/checkpoints"

# Create the checkpoint directory up front; a pre-existing directory is fine.
os.makedirs(CKPT_PATH, exist_ok=True)


In [None]:
class TrainValDataset(Dataset):
    """Labelled image dataset read from ``{DATASET_PATH}/{mode}/<label>/*.jpg``.

    The name of the directory that contains an image is parsed as its integer
    class label. ``mode='train'`` applies random augmentation; every other mode
    uses a deterministic resize + center-crop pipeline.

    Args:
        mode: Sub-directory of ``DATASET_PATH`` to read (this notebook uses
            ``'train'`` and ``'val'``).
    """

    def __init__(self, mode='train'):
        super().__init__()
        self.mode = mode
        # glob returns matches in arbitrary, filesystem-dependent order;
        # sorting keeps the index -> image mapping stable between runs.
        self.img_list = sorted(glob.glob(f'{DATASET_PATH}/{mode}/*/*.jpg'))

        # Both pipelines end with the same tensor conversion + normalization.
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        if mode == 'train':
            self.preprocess = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
                transforms.RandomRotation(10),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
                to_tensor,
                normalize,
            ])
        else:
            self.preprocess = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                to_tensor,
                normalize,
            ])

        print(f"=> Loaded {len(self.img_list)} images for {mode}.")

    def __len__(self):
        """Number of image files found for this mode."""
        return len(self.img_list)

    def __getitem__(self, index):
        """Return ``(image_tensor, label)`` for the image at ``index``.

        ``label`` is a 0-dim ``torch.long`` tensor.
        """
        img_path = self.img_list[index]

        # The label is the name of the directory holding the image. os.path is
        # used instead of splitting on '/' so parsing does not depend on the
        # platform's path separator.
        label = int(os.path.basename(os.path.dirname(img_path)))

        # Release the file handle as soon as the pixels are decoded; convert()
        # returns an independent copy that stays valid after the file closes.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')

        processed_img = self.preprocess(img)
        return processed_img, torch.tensor(label, dtype=torch.long)


class TestDataset(Dataset):
    """Unlabelled test images from ``{DATASET_PATH}/test/*.jpg`` with test-time augmentation.

    Each item is ``(image_name, views)`` where ``views`` stacks one tensor per
    entry of ``self.tta_transforms``.
    """

    def __init__(self):
        super().__init__()
        # Sorted so the order of items (and therefore of the rows written by
        # the evaluation step) does not depend on filesystem enumeration order.
        self.img_list = sorted(glob.glob(f'{DATASET_PATH}/test/*.jpg'))

        def build(*spatial_steps):
            """Resize(256) -> given spatial steps -> ToTensor -> Normalize."""
            return transforms.Compose([
                transforms.Resize(256),
                *spatial_steps,
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

        # Plain single-view pipeline. __getitem__ does not use it; it is kept
        # as a public attribute for code that may reference it.
        self.preprocess = build(transforms.CenterCrop(224))
        print(f"=> Loaded {len(self.img_list)} images for testing.")

        # Test-time augmentation views. The rotation and resized-crop views are
        # random, so two passes over the same image can produce different views.
        self.tta_transforms = [
            build(transforms.CenterCrop(224)),
            build(transforms.RandomHorizontalFlip(p=1.0), transforms.CenterCrop(224)),
            build(transforms.RandomRotation(10), transforms.CenterCrop(224)),
            build(transforms.RandomResizedCrop(224, scale=(0.8, 1.0))),
        ]

    def __len__(self):
        """Number of test image files found."""
        return len(self.img_list)

    def __getitem__(self, index):
        """Return ``(image_name, views)`` for the image at ``index``.

        ``image_name`` is the file name without its extension; ``views`` has
        shape ``(num_tta, C, H, W)``.
        """
        img_path = self.img_list[index]
        # splitext removes only the final extension, so names that contain
        # extra dots (e.g. "img.v2.jpg" -> "img.v2") are not truncated.
        img_name = os.path.splitext(os.path.basename(img_path))[0]

        # Close the file handle once decoded; convert() returns a detached copy.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')

        # Apply all TTA transforms
        tta_imgs = [transform(img) for transform in self.tta_transforms]

        return img_name, torch.stack(tta_imgs)  # Shape: (num_tta, C, H, W)


In [None]:
class ResNet(nn.Module):
    """torchvision ResNet-152 (``weights='DEFAULT'``) with a replaced classifier head.

    The backbone's ``fc`` layer is swapped for ``Dropout(0.5) -> Linear`` that
    maps the backbone features to ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes of the new head.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        backbone = models.resnet152(weights='DEFAULT')

        # Read the feature width from the stock head before replacing it.
        feature_dim = backbone.fc.in_features
        backbone.fc = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(feature_dim, num_classes),
        )

        self.resnet = backbone

    def forward(self, img):
        """Run ``img`` through the backbone and return its output."""
        return self.resnet(img)

def load_model(model, optimizer, save_path, device="cuda"):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        model: Module that receives ``checkpoint['model_state_dict']``.
        optimizer: Optimizer that receives ``checkpoint['optimizer_state_dict']``,
            or ``None`` to skip restoring optimizer state (e.g. inference only).
        save_path: Location of the checkpoint file.
        device: Passed to ``torch.load`` as ``map_location``. A CUDA target
            falls back to ``"cpu"`` when CUDA is not available.

    Returns:
        ``(model, optimizer, epoch)`` where ``epoch`` is ``checkpoint['epoch']``.
    """
    print(f"=> Loading checkpoint '{save_path}'...")

    # Mapping tensors onto CUDA fails on a CPU-only host; load on CPU instead.
    wants_cuda = isinstance(device, (str, torch.device)) and str(device).startswith("cuda")
    if wants_cuda and not torch.cuda.is_available():
        print("=> CUDA is not available; loading checkpoint on CPU.")
        device = "cpu"

    # NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
    checkpoint = torch.load(save_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    print("=> Checkpoint Loaded.")
    return model, optimizer, epoch

def save_model(model, optimizer, epoch, save_path):
    """Write a checkpoint holding the epoch, model state and optimizer state.

    When ``save_path`` is a filesystem path, missing parent directories are
    created and the file is written to ``<save_path>.tmp`` first and then
    renamed into place, so an interrupted save cannot leave a truncated
    checkpoint under the final name.

    Args:
        model: Module whose ``state_dict()`` is stored under ``'model_state_dict'``.
        optimizer: Optimizer whose ``state_dict()`` is stored under
            ``'optimizer_state_dict'``.
        epoch: Value stored under ``'epoch'``.
        save_path: Destination path (``str`` / ``os.PathLike``). Any other
            object is handed to ``torch.save`` unchanged.
    """
    print(f"=> Saving checkpoint to '{save_path}'...")
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }

    if isinstance(save_path, (str, os.PathLike)):
        target = os.fspath(save_path)
        parent = os.path.dirname(target)
        if parent:
            os.makedirs(parent, exist_ok=True)

        tmp_path = f"{target}.tmp"
        try:
            torch.save(state, tmp_path)
            # os.replace overwrites an existing checkpoint in a single step.
            os.replace(tmp_path, target)
        finally:
            # Only present here if the save or the rename failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
    else:
        torch.save(state, save_path)

    print("=> Checkpoint Saved.")


In [None]:
def get_model_size(model):
    """Print the total number of elements in ``model.parameters()``, in millions."""
    element_count = 0
    for tensor in model.parameters():
        element_count += tensor.numel()

    param_num = element_count / 1000000.0
    print(f'#Parameters: {param_num:.2f}M')

In [None]:
def validate(model, dataloader, criterion, best_val_acc=float('-inf')):
    """Evaluate ``model`` on ``dataloader`` and track the best accuracy seen.

    The model is moved to the module-level ``device`` and switched to eval
    mode; it is NOT switched back to train mode afterwards.

    Args:
        model: Module to evaluate.
        dataloader: Iterable yielding ``(img, label)`` batches.
        criterion: Callable ``criterion(pred, label)`` returning a scalar loss tensor.
        best_val_acc: Best accuracy (percent) observed before this call.

    Returns:
        ``(avg_loss, best_val_acc, acc)``: the mean of the per-batch losses,
        the (possibly updated) best accuracy, and this run's accuracy in percent.
    """
    model.to(device)

    print("=> Starting validation...")
    model.eval()

    running_loss = 0.0
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        progress = tqdm(dataloader, desc="Validation")

        for images, targets in progress:
            images = images.to(device)
            targets = targets.to(device)

            logits = model(images)
            batch_loss = criterion(logits, targets).item()

            running_loss += batch_loss
            progress.set_postfix(loss=batch_loss)

            # Index of the highest score along the class dimension.
            top1 = torch.max(logits, 1).indices
            num_correct += (top1 == targets).sum().item()
            num_seen += targets.size(0)

    avg_loss = running_loss / len(dataloader)
    acc = 100 * num_correct / num_seen
    print(f"Validation Avg Loss: {avg_loss}")
    print(f"Validation Avg acc: {acc}")
    print(f"best{best_val_acc} , {acc}")

    # Update the best validation accuracy when this run beats it.
    if acc > best_val_acc:
        best_val_acc = acc
        print(f"✅ New best validation acc: {best_val_acc:.6f}")

    return avg_loss, best_val_acc, acc


In [None]:
def evaluate(model, dataloader, result_path='/kaggle/working/prediction.csv'):
    """Predict every batch of ``dataloader`` with TTA averaging and write a CSV.

    Each batch is ``(img_names, tta_imgs)`` with ``tta_imgs`` unpacked as
    ``(batch, num_tta, C, H, W)``. The softmax outputs of the views of one
    image are averaged and the arg-max class is written next to the image name.

    Args:
        model: Module to run; moved to the module-level ``device`` and put in eval mode.
        dataloader: Iterable yielding ``(img_names, tta_imgs)`` batches.
        result_path: CSV file to (over)write with ``image_name,pred_label`` rows.
    """
    model.to(device)
    print("=> Starting evaluation...")
    model.eval()

    with open(result_path, 'w', newline='') as out_file, torch.no_grad():
        csv_out = csv.writer(out_file)
        csv_out.writerow(['image_name', 'pred_label'])

        for names, views in tqdm(dataloader, desc="Evaluation"):
            n_images, n_views, channels, height, width = views.shape

            # Fold the TTA axis into the batch axis for a single forward pass.
            flat_views = views.view(-1, channels, height, width).to(device)
            probs = torch.softmax(model(flat_views), dim=1)

            # Restore the (image, view) layout and average over the views.
            mean_probs = probs.view(n_images, n_views, -1).mean(dim=1)
            labels = mean_probs.argmax(dim=1).cpu().numpy()

            csv_out.writerows(zip(names, labels))


In [None]:
def run_test(ckpt_root = "/kaggle/working/checkpoints", result_path="/kaggle/working/prediction.csv"):
    """Predict the test set with ``{ckpt_root}/best_checkpoint.pth`` and write a CSV.

    Args:
        ckpt_root: Directory that contains ``best_checkpoint.pth``.
        result_path: Destination of the prediction CSV.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist. Checked before
            the dataset and backbone are built so the failure is immediate.
    """
    ckpt_path = f"{ckpt_root}/best_checkpoint.pth"
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    print(f"✅ Using checkpoint: {ckpt_path}")

    # Load Dataset
    dataset = TestDataset()
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    # Initialize Model
    model = ResNet(num_classes=100).to(device)

    # load_model() restores optimizer state as well, so an optimizer has to be
    # supplied even though it is never stepped during inference.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9, weight_decay=5e-4)
    model, optimizer, _ = load_model(model, optimizer, ckpt_path, device)
    get_model_size(model)

    # Evaluate Model
    evaluate(model, dataloader, result_path)


In [None]:
def mixup_data(x, y, alpha=1.0):
    """Blend each sample of a batch with another sample of the same batch (mixup).

    Args:
        x: Batch of inputs; the first dimension is the batch dimension.
        y: Targets aligned with ``x`` along the first dimension.
        alpha: Parameter of the ``Beta(alpha, alpha)`` distribution the mixing
            coefficient is drawn from. ``alpha <= 0`` disables mixing (``lam = 1``).

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` where ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a`` is ``y``, ``y_b`` is ``y[perm]`` and ``perm`` is a random
        permutation of the batch indices.
    """
    # np.random.beta rejects non-positive parameters; treat them as "no mixing".
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0

    batch_size = x.size(0)
    # Keep the permutation on the same device as the data it indexes, rather
    # than on a module-level device that may differ from x's.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(pred, y_a, y_b, lam, criterion):
    """Combine the losses against both mixup targets, weighted by ``lam`` and ``1 - lam``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

In [None]:
import pandas as pd

def _save_curve_figure(epochs, curves, ylabel, title, out_path):
    """Plot each ``(values, label, marker)`` curve against ``epochs`` and save it to ``out_path``."""
    fig, ax = plt.subplots(figsize=(8, 6))
    for values, label, marker in curves:
        ax.plot(epochs, values, label=label, marker=marker)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

    fig.savefig(out_path)
    # This runs once per epoch; without closing, every call would leave
    # another figure open in pyplot's registry.
    plt.close(fig)


def save_trainingdata():
    """Write the training history to a CSV file and two PNG plots.

    Reads the module-level ``start_epoch``, ``train_losses``, ``val_losses``,
    ``train_accs`` and ``val_accs``. Epoch numbers are 1-based and start at
    ``start_epoch + 1``.

    Outputs (under ``/kaggle/working``): ``training_results.csv``,
    ``loss_plot.png`` and ``acc_plot.png``.
    """
    epochs = list(range(start_epoch + 1, start_epoch + len(train_losses) + 1))

    df = pd.DataFrame({
        'epoch': epochs,
        'train_loss': train_losses,
        'val_loss': val_losses,
        'train_acc': train_accs,
        'val_acc': val_accs
    })

    # Save the CSV
    csv_path = "/kaggle/working/training_results.csv"
    df.to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")

    _save_curve_figure(
        epochs,
        [(train_losses, "Train Loss", "o"), (val_losses, "Validation Loss", "s")],
        ylabel="Loss",
        title="Training and Validation Loss",
        out_path="/kaggle/working/loss_plot.png",
    )
    _save_curve_figure(
        epochs,
        [(train_accs, "Train Acc", "o"), (val_accs, "Validation Acc", "s")],
        ylabel="Acc",
        title="Training and Validation Acc",
        out_path="/kaggle/working/acc_plot.png",
    )


In [None]:
# ---- Run configuration ----
mode = "train"
ckpt_root = "/kaggle/working/checkpoints"
save_per_epoch = 1  # currently unused: a checkpoint is written after every epoch
batch_size = 64
num_epochs = 100
alpha_mixup = 0.2
patience = 20
resume = True

# Load Dataset
dataset = TrainValDataset(mode)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

val_dataset = TrainValDataset("val")
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Per-epoch history; save_trainingdata() reads these module-level lists.
train_losses = []
train_accs = []
val_losses = []
val_accs = []

# Initialize Model
model = ResNet(num_classes=100).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9, weight_decay=5e-4)
# The scheduler is stepped once per epoch, so the cosine period is tied to the run length.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

start_epoch = 0
early_stop_count = 0
best_val_acc = float('-inf')

if resume:
    # NOTE(review): only model, optimizer and epoch are restored here. The scheduler
    # position and best_val_acc restart from their initial values, so the first
    # resumed epoch always overwrites best_checkpoint.pth.
    ckpt_path = '/kaggle/input/dl-hw1-model/last_checkpoint.pth'
    model, optimizer, start_epoch = load_model(model, optimizer, ckpt_path, device)
    print(f"load_model start epoch = {start_epoch}")

# 1-based number of the epoch that produced the best validation accuracy so far.
best_epoch = start_epoch
print("=> Starting training...")

for epoch in range(start_epoch, num_epochs):
    # validate() switches the model to eval mode at the end of every epoch, so
    # train mode must be re-enabled here. Otherwise every epoch after the first
    # would train with dropout disabled and BatchNorm statistics frozen.
    model.train()

    epoch_loss, correct, total = 0.0, 0, 0
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for img, label in pbar:
        img, label = img.to(device), label.to(device)
        optimizer.zero_grad()

        # Apply MixUp
        mixed_img, y_a, y_b, lam = mixup_data(img, label, alpha=alpha_mixup)
        pred = model(mixed_img)
        loss = mixup_criterion(pred, y_a, y_b, lam, criterion)

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        pbar.set_postfix(loss=batch_loss)

        # With mixed targets, a prediction counts as lam-correct for y_a and
        # (1 - lam)-correct for y_b, so `correct` is a float.
        _, predicted = torch.max(pred, 1)
        correct += (lam * (predicted == y_a).sum().item() + (1 - lam) * (predicted == y_b).sum().item())
        total += label.size(0)

    avg_train_acc = 100 * correct / total
    avg_train_loss = epoch_loss / len(dataloader)
    train_losses.append(avg_train_loss)
    train_accs.append(avg_train_acc)
    print(f"Epoch {epoch+1} - Avg Train Loss: {avg_train_loss}")
    print(f"Epoch {epoch+1} - Avg Train Acc:  {avg_train_acc}")

    # Validate after each epoch. Remember the previous best so that an
    # improvement is detected with a strict comparison rather than float equality.
    previous_best_acc = best_val_acc
    val_loss, best_val_acc, val_acc = validate(model, val_dataloader, criterion, best_val_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    scheduler.step()
    save_model(model, optimizer, epoch + 1, f'{ckpt_root}/last_checkpoint.pth')

    # Save the model when validation accuracy improved.
    if val_acc > previous_best_acc:
        save_model(model, optimizer, epoch + 1, f'{ckpt_root}/best_checkpoint.pth')
        best_epoch = epoch + 1
        early_stop_count = 0
    else:
        early_stop_count += 1
    save_trainingdata()

    # Early stop once more than `patience` consecutive epochs brought no improvement.
    if early_stop_count > patience:
        print(f"Early stop at epoch {epoch + 1}.\n No improvement since epoch {best_epoch}.")
        break

# Automatically test with the best checkpoint
print("🎯 Training complete! Running test with the best checkpoint...")
run_test(ckpt_root)


