In [1]:
import torch
from torch import nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import train_and_test_model
from train_and_test_model import trainModel,testModel
from torchvision.transforms import TrivialAugmentWide,AutoAugment, AutoAugmentPolicy
from ResNet_Classifier import classifierTrain

In [3]:
# Per-channel normalization statistics (the values commonly used for ImageNet).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Settings shared by every split, kept in one place so they are easy to tune.
DATA_ROOT = './datasets'
BATCH_SIZE = 32
NUM_WORKERS = 4
SEED = 42
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)

# The same deterministic preprocessing is applied to train, test and val.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet is usually fed 224x224 inputs
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

train_dataset_origin = datasets.ImageFolder(root=f'{DATA_ROOT}/train', transform=transform)
# Only the training loader shuffles; a seeded generator makes that order
# (and the base seed handed to worker processes) identical on every
# Restart & Run All.
train_loader_origin = DataLoader(
    train_dataset_origin,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
    **loader_kwargs,
)

test_dataset_origin = datasets.ImageFolder(root=f'{DATA_ROOT}/test', transform=transform)
test_loader_origin = DataLoader(test_dataset_origin, shuffle=False, **loader_kwargs)

val_dataset_origin = datasets.ImageFolder(root=f'{DATA_ROOT}/val', transform=transform)
val_loader_origin = DataLoader(val_dataset_origin, shuffle=False, **loader_kwargs)


## Baseline: ResNet-50 + SVM Classifier

In [4]:
# Baseline run: ResNet-50 with an SVM classifier, scored on the test loader.
classifierTrain(50, train_loader_origin, test_loader_origin, 'SVM')

ResNet50 SVM Test Accuracy:0.9611111111111111

📊 Classification Report:
              precision    recall  f1-score   support

           0     0.9667    0.9667    0.9667       120
           1     0.9487    0.9250    0.9367       120
           2     0.9915    0.9667    0.9789       120
           3     0.9667    0.9667    0.9667       120
           4     0.9833    0.9833    0.9833       120
           5     0.9600    1.0000    0.9796       120
           6     0.9744    0.9500    0.9620       120
           7     0.9658    0.9417    0.9536       120
           8     0.9339    0.9417    0.9378       120
           9     0.9750    0.9750    0.9750       120
          10     0.9831    0.9667    0.9748       120
          11     0.9590    0.9750    0.9669       120
          12     0.9256    0.9333    0.9295       120
          13     0.9752    0.9833    0.9793       120
          14     0.9113    0.9417    0.9262       120

    accuracy                         0.9611      1800
   macro

## Experiment 1: Change the ResNet depth

In [5]:
# Same SVM setup as the baseline, but with the shallower ResNet-18.
classifierTrain(18, train_loader_origin, test_loader_origin, 'SVM')

ResNet18 SVM Test Accuracy:0.9416666666666667

📊 Classification Report:
              precision    recall  f1-score   support

           0     0.9426    0.9583    0.9504       120
           1     0.8661    0.9167    0.8907       120
           2     0.9504    0.9583    0.9544       120
           3     0.9310    0.9000    0.9153       120
           4     1.0000    0.9750    0.9873       120
           5     0.9914    0.9583    0.9746       120
           6     0.9669    0.9750    0.9710       120
           7     0.9328    0.9250    0.9289       120
           8     0.9322    0.9167    0.9244       120
           9     0.9508    0.9667    0.9587       120
          10     0.9915    0.9667    0.9789       120
          11     0.9496    0.9417    0.9456       120
          12     0.8760    0.8833    0.8797       120
          13     0.9667    0.9667    0.9667       120
          14     0.8871    0.9167    0.9016       120

    accuracy                         0.9417      1800
   macro

In [6]:
# Same SVM setup as the baseline, but with the deeper ResNet-101.
classifierTrain(101, train_loader_origin, test_loader_origin, 'SVM')

ResNet101 SVM Test Accuracy:0.9533333333333334

📊 Classification Report:
              precision    recall  f1-score   support

           0     0.9280    0.9667    0.9469       120
           1     0.9316    0.9083    0.9198       120
           2     0.9915    0.9750    0.9832       120
           3     0.9127    0.9583    0.9350       120
           4     0.9915    0.9750    0.9832       120
           5     0.9754    0.9917    0.9835       120
           6     0.9752    0.9833    0.9793       120
           7     0.9823    0.9250    0.9528       120
           8     0.9322    0.9167    0.9244       120
           9     0.9512    0.9750    0.9630       120
          10     0.9832    0.9750    0.9791       120
          11     0.9748    0.9667    0.9707       120
          12     0.9167    0.9167    0.9167       120
          13     0.9669    0.9750    0.9710       120
          14     0.8917    0.8917    0.8917       120

    accuracy                         0.9533      1800
   macr

## Experiment 2: Change the classifier

In [7]:
# ResNet-50 kept fixed; the classifier is switched from SVM to KNN.
classifierTrain(50, train_loader_origin, test_loader_origin, 'KNN')

ResNet50 KNN Test Accuracy:0.9183333333333333

📊 Classification Report:
              precision    recall  f1-score   support

           0     0.9268    0.9500    0.9383       120
           1     0.8235    0.8167    0.8201       120
           2     0.9583    0.9583    0.9583       120
           3     0.9322    0.9167    0.9244       120
           4     0.9832    0.9750    0.9791       120
           5     0.9141    0.9750    0.9435       120
           6     0.9339    0.9417    0.9378       120
           7     0.9391    0.9000    0.9191       120
           8     0.8571    0.8500    0.8536       120
           9     0.9421    0.9500    0.9461       120
          10     0.9735    0.9167    0.9442       120
          11     0.9407    0.9250    0.9328       120
          12     0.8134    0.9083    0.8583       120
          13     0.9667    0.9667    0.9667       120
          14     0.8839    0.8250    0.8534       120

    accuracy                         0.9183      1800
   macro

In [8]:
# ResNet-50 kept fixed; the classifier is switched from SVM to an MLP.
classifierTrain(50, train_loader_origin, test_loader_origin, 'MLP')

ResNet50 MLP Test Accuracy:0.9483333333333334

📊 Classification Report:
              precision    recall  f1-score   support

           0     0.9587    0.9667    0.9627       120
           1     0.9474    0.9000    0.9231       120
           2     0.9748    0.9667    0.9707       120
           3     0.9661    0.9500    0.9580       120
           4     0.9669    0.9750    0.9710       120
           5     0.9504    0.9583    0.9544       120
           6     0.9426    0.9583    0.9504       120
           7     0.9576    0.9417    0.9496       120
           8     0.9322    0.9167    0.9244       120
           9     0.9500    0.9500    0.9500       120
          10     0.9661    0.9500    0.9580       120
          11     0.9286    0.9750    0.9512       120
          12     0.8898    0.9417    0.9150       120
          13     0.9664    0.9583    0.9623       120
          14     0.9322    0.9167    0.9244       120

    accuracy                         0.9483      1800
   macro

  X_train = torch.tensor(train_features, dtype=torch.float32).to(device)
  y_train = torch.tensor(train_labels, dtype=torch.long).to(device)
  X_test = torch.tensor(test_features, dtype=torch.float32).to(device)


In [9]:
# ResNet-50 kept fixed; 'Proto' selects the ProtoNetClassifier from ResNet_Classifier
# (the class name shown in this cell's output).
classifierTrain(50, train_loader_origin, test_loader_origin, 'Proto')

ResNet50 <ResNet_Classifier.ProtoNetClassifier object at 0x7f162135e960> Test Accuracy:0.8238888888888889

📊 Classification Report:
              precision    recall  f1-score   support

           0     0.9561    0.9083    0.9316       120
           1     0.5985    0.6833    0.6381       120
           2     0.9908    0.9000    0.9432       120
           3     0.8130    0.8333    0.8230       120
           4     1.0000    0.9000    0.9474       120
           5     0.9293    0.7667    0.8402       120
           6     0.7760    0.8083    0.7918       120
           7     0.8099    0.8167    0.8133       120
           8     0.6992    0.7167    0.7078       120
           9     0.7877    0.9583    0.8647       120
          10     0.9412    0.8000    0.8649       120
          11     0.8772    0.8333    0.8547       120
          12     0.7143    0.8333    0.7692       120
          13     0.8537    0.8750    0.8642       120
          14     0.7500    0.7250    0.7373       120

  

## Experiment 3: Use fine-tuning

In [10]:
# Fine-tune ResNet-50 with the train and val loaders, then score it on the test loader.
# NOTE(review): the 'origin' tag is passed to both calls; presumably it identifies the
# saved best-weight file that testModel loads. Confirm in train_and_test_model.
trainModel(50, train_loader_origin, val_loader_origin, 'origin')
testModel(50, test_loader_origin, 'origin')

  bestweight = torch.load(bestweight_path)
  bestweight = torch.load(bestweight_path)


✅ ResNet-50, Accuracy: 93.17%

📊 Classification Report:
              precision    recall  f1-score   support

           0     0.9008    0.9833    0.9402       120
           1     0.8871    0.9167    0.9016       120
           2     0.9504    0.9583    0.9544       120
           3     0.9818    0.9000    0.9391       120
           4     1.0000    0.9750    0.9873       120
           5     0.9360    0.9750    0.9551       120
           6     0.9720    0.8667    0.9163       120
           7     0.9714    0.8500    0.9067       120
           8     0.8102    0.9250    0.8638       120
           9     0.9355    0.9667    0.9508       120
          10     0.9590    0.9750    0.9669       120
          11     0.9344    0.9500    0.9421       120
          12     0.8880    0.9250    0.9061       120
          13     0.9831    0.9667    0.9748       120
          14     0.9018    0.8417    0.8707       120

    accuracy                         0.9317      1800
   macro avg     0.9341 