In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import resnet152
from tqdm import tqdm
import time
import copy
import os
import csv
from torch.utils.data import Dataset
from PIL import Image

In [2]:
# Channel Attention Module
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

# Spatial Attention Module
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        padding = (kernel_size - 1) // 2
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

# CBAM Module
class CBAM(nn.Module):
    def __init__(self, planes, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(planes, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        out = x * self.channel_attention(x)
        out = out * self.spatial_attention(out)
        return out

In [3]:
class ResNet152_CBAM(nn.Module):
    def __init__(self, pretrained=True, num_classes=2):
        super(ResNet152_CBAM, self).__init__()
        self.model = resnet152(pretrained=pretrained)

        # CBAM modules
        self.cbam1 = CBAM(256)   # After layer1
        self.cbam2 = CBAM(512)   # After layer2
        self.cbam3 = CBAM(1024)  # After layer3
        self.cbam4 = CBAM(2048)  # After layer4

        # Modify the classifier
        self.model.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        x = self.cbam1(x)

        x = self.model.layer2(x)
        x = self.cbam2(x)

        x = self.model.layer3(x)
        x = self.cbam3(x)

        x = self.model.layer4(x)
        x = self.cbam4(x)

        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.model.fc(x)

        return x

In [4]:
data_dir = '/home/gpl/文件/Kezia/Visual Recognition/HW1/hw1-data/data'
batch_size = 16
num_workers = 4
num_classes = 100

# Data augmentations and normalization
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
}

# Datasets and Dataloaders
image_datasets = {x: datasets.ImageFolder(root=f"{data_dir}/{x}",
                                          transform=data_transforms[x])
                  for x in ['train', 'val']}

dataloaders = {x: DataLoader(image_datasets[x], batch_size=batch_size,
                             shuffle=True, num_workers=num_workers)
               for x in ['train', 'val']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            loop = tqdm(dataloaders[phase], total=len(dataloaders[phase]),
                        desc=f"{phase.capitalize()} Epoch {epoch+1}")

            for inputs, labels in loop:
                inputs = inputs.to(device)
                labels = labels.to(device)


                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                # print(labels.data)

                loop.set_postfix(loss=running_loss /
                                 ((loop.n + 1) * batch_size),
                                 acc=running_corrects.double().item() /
                                 ((loop.n + 1) * batch_size))

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        scheduler.step()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m\
          {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model


In [6]:
# Initialize CBAM-ResNet152
model_cbam = ResNet152_CBAM(pretrained=True, num_classes=num_classes)
model_cbam = model_cbam.to(device)

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer (finetune all parameters)
optimizer = optim.Adam(model_cbam.parameters(), lr=0.0001)

# Learning rate scheduler
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)




In [7]:
num_epochs = 100
model_cbam = train_model(model_cbam, criterion, optimizer, exp_lr_scheduler,
                         num_epochs=num_epochs)

# Save the best model
torch.save(model_cbam.state_dict(), 'resnet152_cbam_best100.pth')
print('Model saved as resnet152_cbam_best100.pth')

Epoch 1/100
----------


Train Epoch 1: 100%|██████████| 1296/1296 [08:52<00:00,  2.43it/s, acc=0.505, loss=2.06]


train Loss: 2.0638 Acc: 0.5052


Val Epoch 1: 100%|██████████| 19/19 [00:08<00:00,  2.34it/s, acc=0.628, loss=1.56]


val Loss: 1.5010 Acc: 0.6033
Epoch 2/100
----------


Train Epoch 2: 100%|██████████| 1296/1296 [06:30<00:00,  3.32it/s, acc=0.681, loss=1.21]


train Loss: 1.2129 Acc: 0.6816


Val Epoch 2: 100%|██████████| 19/19 [00:02<00:00,  9.38it/s, acc=0.705, loss=1.36]


val Loss: 1.3035 Acc: 0.6767
Epoch 3/100
----------


Train Epoch 3: 100%|██████████| 1296/1296 [06:49<00:00,  3.16it/s, acc=0.734, loss=1]   


train Loss: 1.0031 Acc: 0.7344


Val Epoch 3: 100%|██████████| 19/19 [00:02<00:00,  8.38it/s, acc=0.701, loss=1.11]


val Loss: 1.1250 Acc: 0.7100
Epoch 4/100
----------


Train Epoch 4: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.758, loss=0.911]


train Loss: 0.9118 Acc: 0.7588


Val Epoch 4: 100%|██████████| 19/19 [00:02<00:00,  8.39it/s, acc=0.747, loss=0.899]


val Loss: 0.9114 Acc: 0.7567
Epoch 5/100
----------


Train Epoch 5: 100%|██████████| 1296/1296 [06:36<00:00,  3.27it/s, acc=0.776, loss=0.825]


train Loss: 0.8252 Acc: 0.7762


Val Epoch 5: 100%|██████████| 19/19 [00:02<00:00,  9.35it/s, acc=0.809, loss=0.89] 


val Loss: 0.8547 Acc: 0.7767
Epoch 6/100
----------


Train Epoch 6: 100%|██████████| 1296/1296 [06:29<00:00,  3.32it/s, acc=0.795, loss=0.768]


train Loss: 0.7684 Acc: 0.7952


Val Epoch 6: 100%|██████████| 19/19 [00:02<00:00,  9.27it/s, acc=0.774, loss=0.946]


val Loss: 0.9080 Acc: 0.7433
Epoch 7/100
----------


Train Epoch 7: 100%|██████████| 1296/1296 [06:42<00:00,  3.22it/s, acc=0.8, loss=0.735]  


train Loss: 0.7350 Acc: 0.8007


Val Epoch 7: 100%|██████████| 19/19 [00:02<00:00,  9.00it/s, acc=0.781, loss=0.978]


val Loss: 0.9389 Acc: 0.7500
Epoch 8/100
----------


Train Epoch 8: 100%|██████████| 1296/1296 [06:48<00:00,  3.18it/s, acc=0.858, loss=0.524]


train Loss: 0.5239 Acc: 0.8583


Val Epoch 8: 100%|██████████| 19/19 [00:02<00:00,  8.86it/s, acc=0.829, loss=0.672]


val Loss: 0.6808 Acc: 0.8400
Epoch 9/100
----------


Train Epoch 9: 100%|██████████| 1296/1296 [06:47<00:00,  3.18it/s, acc=0.88, loss=0.438] 


train Loss: 0.4385 Acc: 0.8805


Val Epoch 9: 100%|██████████| 19/19 [00:02<00:00,  9.14it/s, acc=0.822, loss=0.639]


val Loss: 0.6480 Acc: 0.8333
Epoch 10/100
----------


Train Epoch 10: 100%|██████████| 1296/1296 [07:10<00:00,  3.01it/s, acc=0.887, loss=0.413]


train Loss: 0.4136 Acc: 0.8879


Val Epoch 10: 100%|██████████| 19/19 [00:02<00:00,  8.37it/s, acc=0.832, loss=0.623]


val Loss: 0.6314 Acc: 0.8433
Epoch 11/100
----------


Train Epoch 11: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.893, loss=0.39] 


train Loss: 0.3901 Acc: 0.8934


Val Epoch 11: 100%|██████████| 19/19 [00:02<00:00,  8.36it/s, acc=0.832, loss=0.603]


val Loss: 0.6110 Acc: 0.8433
Epoch 12/100
----------


Train Epoch 12: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.895, loss=0.384]


train Loss: 0.3839 Acc: 0.8951


Val Epoch 12: 100%|██████████| 19/19 [00:02<00:00,  8.41it/s, acc=0.816, loss=0.575]


val Loss: 0.5823 Acc: 0.8267
Epoch 13/100
----------


Train Epoch 13: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.901, loss=0.363]


train Loss: 0.3634 Acc: 0.9013


Val Epoch 13: 100%|██████████| 19/19 [00:02<00:00,  8.41it/s, acc=0.842, loss=0.588]


val Loss: 0.5960 Acc: 0.8533
Epoch 14/100
----------


Train Epoch 14: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.905, loss=0.344]


train Loss: 0.3440 Acc: 0.9052


Val Epoch 14: 100%|██████████| 19/19 [00:01<00:00,  9.57it/s, acc=0.882, loss=0.632]


val Loss: 0.6066 Acc: 0.8467
Epoch 15/100
----------


Train Epoch 15: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.914, loss=0.318]


train Loss: 0.3180 Acc: 0.9150


Val Epoch 15: 100%|██████████| 19/19 [00:02<00:00,  8.35it/s, acc=0.852, loss=0.601]


val Loss: 0.6091 Acc: 0.8633
Epoch 16/100
----------


Train Epoch 16: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.909, loss=0.334]


train Loss: 0.3340 Acc: 0.9095


Val Epoch 16: 100%|██████████| 19/19 [00:02<00:00,  8.39it/s, acc=0.849, loss=0.573]


val Loss: 0.5804 Acc: 0.8600
Epoch 17/100
----------


Train Epoch 17: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.912, loss=0.325]


train Loss: 0.3250 Acc: 0.9127


Val Epoch 17: 100%|██████████| 19/19 [00:02<00:00,  8.39it/s, acc=0.845, loss=0.589]


val Loss: 0.5967 Acc: 0.8567
Epoch 18/100
----------


Train Epoch 18: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.912, loss=0.321]


train Loss: 0.3210 Acc: 0.9128


Val Epoch 18: 100%|██████████| 19/19 [00:02<00:00,  8.42it/s, acc=0.845, loss=0.582]


val Loss: 0.5901 Acc: 0.8567
Epoch 19/100
----------


Train Epoch 19: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.914, loss=0.317]


train Loss: 0.3176 Acc: 0.9147


Val Epoch 19: 100%|██████████| 19/19 [00:02<00:00,  8.41it/s, acc=0.852, loss=0.577]


val Loss: 0.5845 Acc: 0.8633
Epoch 20/100
----------


Train Epoch 20: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.914, loss=0.311]


train Loss: 0.3108 Acc: 0.9145


Val Epoch 20: 100%|██████████| 19/19 [00:02<00:00,  8.32it/s, acc=0.849, loss=0.581]


val Loss: 0.5888 Acc: 0.8600
Epoch 21/100
----------


Train Epoch 21: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.918, loss=0.302]


train Loss: 0.3026 Acc: 0.9181


Val Epoch 21: 100%|██████████| 19/19 [00:02<00:00,  8.38it/s, acc=0.836, loss=0.599]


val Loss: 0.6066 Acc: 0.8467
Epoch 22/100
----------


Train Epoch 22: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.917, loss=0.303]


train Loss: 0.3035 Acc: 0.9179


Val Epoch 22: 100%|██████████| 19/19 [00:02<00:00,  8.41it/s, acc=0.819, loss=0.71] 


val Loss: 0.7200 Acc: 0.8300
Epoch 23/100
----------


Train Epoch 23: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.918, loss=0.306]


train Loss: 0.3058 Acc: 0.9186


Val Epoch 23: 100%|██████████| 19/19 [00:02<00:00,  8.37it/s, acc=0.855, loss=0.568]


val Loss: 0.5753 Acc: 0.8667
Epoch 24/100
----------


Train Epoch 24: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.914, loss=0.31] 


train Loss: 0.3102 Acc: 0.9150


Val Epoch 24: 100%|██████████| 19/19 [00:02<00:00,  8.35it/s, acc=0.839, loss=0.572]


val Loss: 0.5792 Acc: 0.8500
Epoch 25/100
----------


Train Epoch 25: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.916, loss=0.308]


train Loss: 0.3081 Acc: 0.9163


Val Epoch 25: 100%|██████████| 19/19 [00:02<00:00,  8.59it/s, acc=0.896, loss=0.6]  


val Loss: 0.5762 Acc: 0.8600
Epoch 26/100
----------


Train Epoch 26: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.914, loss=0.312]


train Loss: 0.3123 Acc: 0.9141


Val Epoch 26: 100%|██████████| 19/19 [00:02<00:00,  8.41it/s, acc=0.852, loss=0.567]


val Loss: 0.5744 Acc: 0.8633
Epoch 27/100
----------


Train Epoch 27: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.915, loss=0.31] 


train Loss: 0.3097 Acc: 0.9159


Val Epoch 27: 100%|██████████| 19/19 [00:02<00:00,  8.40it/s, acc=0.852, loss=0.577]


val Loss: 0.5848 Acc: 0.8633
Epoch 28/100
----------


Train Epoch 28: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.915, loss=0.309]


train Loss: 0.3089 Acc: 0.9153


Val Epoch 28: 100%|██████████| 19/19 [00:02<00:00,  8.34it/s, acc=0.836, loss=0.568]


val Loss: 0.5756 Acc: 0.8467
Epoch 29/100
----------


Train Epoch 29: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.915, loss=0.313]


train Loss: 0.3131 Acc: 0.9159


Val Epoch 29: 100%|██████████| 19/19 [00:02<00:00,  8.61it/s, acc=0.839, loss=0.579]


val Loss: 0.5868 Acc: 0.8500
Epoch 30/100
----------


Train Epoch 30: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.917, loss=0.304]


train Loss: 0.3037 Acc: 0.9171


Val Epoch 30: 100%|██████████| 19/19 [00:02<00:00,  8.41it/s, acc=0.859, loss=0.573]


val Loss: 0.5808 Acc: 0.8700
Epoch 31/100
----------


Train Epoch 31: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.918, loss=0.309]


train Loss: 0.3087 Acc: 0.9183


Val Epoch 31: 100%|██████████| 19/19 [00:02<00:00,  8.39it/s, acc=0.836, loss=0.567]


val Loss: 0.5749 Acc: 0.8467
Epoch 32/100
----------


Train Epoch 32: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.918, loss=0.305]


train Loss: 0.3050 Acc: 0.9184


Val Epoch 32: 100%|██████████| 19/19 [00:02<00:00,  8.43it/s, acc=0.839, loss=0.58] 


val Loss: 0.5878 Acc: 0.8500
Epoch 33/100
----------


Train Epoch 33: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.916, loss=0.309]


train Loss: 0.3093 Acc: 0.9163


Val Epoch 33: 100%|██████████| 19/19 [00:02<00:00,  8.38it/s, acc=0.862, loss=0.566]


val Loss: 0.5734 Acc: 0.8733
Epoch 34/100
----------


Train Epoch 34: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.917, loss=0.309]


train Loss: 0.3092 Acc: 0.9173


Val Epoch 34: 100%|██████████| 19/19 [00:02<00:00,  8.42it/s, acc=0.852, loss=0.566]


val Loss: 0.5733 Acc: 0.8633
Epoch 35/100
----------


Train Epoch 35: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.916, loss=0.306]


train Loss: 0.3062 Acc: 0.9168


Val Epoch 35: 100%|██████████| 19/19 [00:02<00:00,  8.39it/s, acc=0.852, loss=0.551]


val Loss: 0.5585 Acc: 0.8633
Epoch 36/100
----------


Train Epoch 36: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.916, loss=0.311]


train Loss: 0.3107 Acc: 0.9164


Val Epoch 36: 100%|██████████| 19/19 [00:02<00:00,  7.92it/s, acc=0.855, loss=0.57] 


val Loss: 0.5771 Acc: 0.8667
Epoch 37/100
----------


Train Epoch 37: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.916, loss=0.307]


train Loss: 0.3077 Acc: 0.9165


Val Epoch 37: 100%|██████████| 19/19 [00:02<00:00,  8.45it/s, acc=0.852, loss=0.578]


val Loss: 0.5857 Acc: 0.8633
Epoch 38/100
----------


Train Epoch 38: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.917, loss=0.309]


train Loss: 0.3088 Acc: 0.9173


Val Epoch 38: 100%|██████████| 19/19 [00:02<00:00,  8.39it/s, acc=0.855, loss=0.577]


val Loss: 0.5844 Acc: 0.8667
Epoch 39/100
----------


Train Epoch 39: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.918, loss=0.306]


train Loss: 0.3058 Acc: 0.9189


Val Epoch 39: 100%|██████████| 19/19 [00:02<00:00,  8.38it/s, acc=0.849, loss=0.583]


val Loss: 0.5911 Acc: 0.8600
Epoch 40/100
----------


Train Epoch 40: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.918, loss=0.303]


train Loss: 0.3032 Acc: 0.9183


Val Epoch 40: 100%|██████████| 19/19 [00:01<00:00,  9.66it/s, acc=0.882, loss=0.606]


val Loss: 0.5815 Acc: 0.8467
Epoch 41/100
----------


Train Epoch 41: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.915, loss=0.313]


train Loss: 0.3133 Acc: 0.9153


Val Epoch 41: 100%|██████████| 19/19 [00:02<00:00,  8.40it/s, acc=0.845, loss=0.57] 


val Loss: 0.5780 Acc: 0.8567
Epoch 42/100
----------


Train Epoch 42: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.915, loss=0.312]


train Loss: 0.3124 Acc: 0.9154


Val Epoch 42: 100%|██████████| 19/19 [00:02<00:00,  8.39it/s, acc=0.849, loss=0.559]


val Loss: 0.5661 Acc: 0.8600
Epoch 43/100
----------


Train Epoch 43: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.916, loss=0.308]


train Loss: 0.3078 Acc: 0.9163


Val Epoch 43: 100%|██████████| 19/19 [00:02<00:00,  8.39it/s, acc=0.845, loss=0.567]


val Loss: 0.5747 Acc: 0.8567
Epoch 44/100
----------


Train Epoch 44: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.916, loss=0.31] 


train Loss: 0.3101 Acc: 0.9161


Val Epoch 44: 100%|██████████| 19/19 [00:02<00:00,  8.43it/s, acc=0.852, loss=0.579]


val Loss: 0.5865 Acc: 0.8633
Epoch 45/100
----------


Train Epoch 45: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.918, loss=0.3]  


train Loss: 0.2998 Acc: 0.9182


Val Epoch 45: 100%|██████████| 19/19 [00:02<00:00,  8.43it/s, acc=0.849, loss=0.575]


val Loss: 0.5831 Acc: 0.8600
Epoch 46/100
----------


Train Epoch 46: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.914, loss=0.315]


train Loss: 0.3153 Acc: 0.9150


Val Epoch 46: 100%|██████████| 19/19 [00:02<00:00,  8.44it/s, acc=0.832, loss=0.593]


val Loss: 0.6014 Acc: 0.8433
Epoch 47/100
----------


Train Epoch 47: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.913, loss=0.315]


train Loss: 0.3148 Acc: 0.9131


Val Epoch 47: 100%|██████████| 19/19 [00:02<00:00,  8.00it/s, acc=0.839, loss=0.583]


val Loss: 0.5910 Acc: 0.8500
Epoch 48/100
----------


Train Epoch 48: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.916, loss=0.305]


train Loss: 0.3056 Acc: 0.9161


Val Epoch 48: 100%|██████████| 19/19 [00:02<00:00,  8.37it/s, acc=0.849, loss=0.584]


val Loss: 0.5919 Acc: 0.8600
Epoch 49/100
----------


Train Epoch 49: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.913, loss=0.316]


train Loss: 0.3161 Acc: 0.9131


Val Epoch 49: 100%|██████████| 19/19 [00:02<00:00,  8.34it/s, acc=0.836, loss=0.584]


val Loss: 0.5922 Acc: 0.8467
Epoch 50/100
----------


Train Epoch 50: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.915, loss=0.31] 


train Loss: 0.3100 Acc: 0.9159


Val Epoch 50: 100%|██████████| 19/19 [00:02<00:00,  8.44it/s, acc=0.852, loss=0.568]


val Loss: 0.5752 Acc: 0.8633
Epoch 51/100
----------


Train Epoch 51: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.92, loss=0.299] 


train Loss: 0.2991 Acc: 0.9209


Val Epoch 51: 100%|██████████| 19/19 [00:01<00:00,  9.66it/s, acc=0.899, loss=0.611]


val Loss: 0.5867 Acc: 0.8633
Epoch 52/100
----------


Train Epoch 52: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.918, loss=0.302]


train Loss: 0.3026 Acc: 0.9188


Val Epoch 52: 100%|██████████| 19/19 [00:02<00:00,  8.38it/s, acc=0.836, loss=0.578]


val Loss: 0.5858 Acc: 0.8467
Epoch 53/100
----------


Train Epoch 53: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.917, loss=0.308]


train Loss: 0.3079 Acc: 0.9178


Val Epoch 53: 100%|██████████| 19/19 [00:02<00:00,  8.38it/s, acc=0.852, loss=0.575]


val Loss: 0.5826 Acc: 0.8633
Epoch 54/100
----------


Train Epoch 54: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.917, loss=0.314]


train Loss: 0.3138 Acc: 0.9175


Val Epoch 54: 100%|██████████| 19/19 [00:02<00:00,  8.42it/s, acc=0.819, loss=0.654]


val Loss: 0.6627 Acc: 0.8300
Epoch 55/100
----------


Train Epoch 55: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.918, loss=0.303]


train Loss: 0.3035 Acc: 0.9182


Val Epoch 55: 100%|██████████| 19/19 [00:02<00:00,  8.40it/s, acc=0.839, loss=0.582]


val Loss: 0.5903 Acc: 0.8500
Epoch 56/100
----------


Train Epoch 56: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.915, loss=0.313]


train Loss: 0.3129 Acc: 0.9156


Val Epoch 56: 100%|██████████| 19/19 [00:02<00:00,  8.41it/s, acc=0.842, loss=0.577]


val Loss: 0.5852 Acc: 0.8533
Epoch 57/100
----------


Train Epoch 57: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.92, loss=0.292] 


train Loss: 0.2919 Acc: 0.9210


Val Epoch 57: 100%|██████████| 19/19 [00:02<00:00,  8.38it/s, acc=0.842, loss=0.581]


val Loss: 0.5887 Acc: 0.8533
Epoch 58/100
----------


Train Epoch 58: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.918, loss=0.303]


train Loss: 0.3036 Acc: 0.9183


Val Epoch 58: 100%|██████████| 19/19 [00:02<00:00,  8.43it/s, acc=0.845, loss=0.58] 


val Loss: 0.5876 Acc: 0.8567
Epoch 59/100
----------


Train Epoch 59: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.917, loss=0.306]


train Loss: 0.3065 Acc: 0.9173


Val Epoch 59: 100%|██████████| 19/19 [00:02<00:00,  8.37it/s, acc=0.849, loss=0.572]


val Loss: 0.5800 Acc: 0.8600
Epoch 60/100
----------


Train Epoch 60: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.916, loss=0.309]


train Loss: 0.3087 Acc: 0.9165


Val Epoch 60: 100%|██████████| 19/19 [00:02<00:00,  8.46it/s, acc=0.845, loss=0.575]


val Loss: 0.5830 Acc: 0.8567
Epoch 61/100
----------


Train Epoch 61: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.914, loss=0.312]


train Loss: 0.3123 Acc: 0.9149


Val Epoch 61: 100%|██████████| 19/19 [00:02<00:00,  8.42it/s, acc=0.845, loss=0.575]


val Loss: 0.5826 Acc: 0.8567
Epoch 62/100
----------


Train Epoch 62: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.914, loss=0.316]


train Loss: 0.3162 Acc: 0.9141


Val Epoch 62: 100%|██████████| 19/19 [00:02<00:00,  8.37it/s, acc=0.849, loss=0.568]


val Loss: 0.5756 Acc: 0.8600
Epoch 63/100
----------


Train Epoch 63: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.916, loss=0.308]


train Loss: 0.3085 Acc: 0.9169


Val Epoch 63: 100%|██████████| 19/19 [00:02<00:00,  8.36it/s, acc=0.849, loss=0.567]


val Loss: 0.5746 Acc: 0.8600
Epoch 64/100
----------


Train Epoch 64: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.915, loss=0.309]


train Loss: 0.3093 Acc: 0.9157


Val Epoch 64: 100%|██████████| 19/19 [00:02<00:00,  8.48it/s, acc=0.839, loss=0.581]


val Loss: 0.5891 Acc: 0.8500
Epoch 65/100
----------


Train Epoch 65: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.917, loss=0.306]


train Loss: 0.3060 Acc: 0.9180


Val Epoch 65: 100%|██████████| 19/19 [00:02<00:00,  8.40it/s, acc=0.849, loss=0.58] 


val Loss: 0.5874 Acc: 0.8600
Epoch 66/100
----------


Train Epoch 66: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.915, loss=0.309]


train Loss: 0.3092 Acc: 0.9157


Val Epoch 66: 100%|██████████| 19/19 [00:02<00:00,  9.21it/s, acc=0.842, loss=0.581]


val Loss: 0.5891 Acc: 0.8533
Epoch 67/100
----------


Train Epoch 67: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.917, loss=0.307]


train Loss: 0.3068 Acc: 0.9177


Val Epoch 67: 100%|██████████| 19/19 [00:02<00:00,  8.40it/s, acc=0.855, loss=0.586]


val Loss: 0.5936 Acc: 0.8667
Epoch 68/100
----------


Train Epoch 68: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.915, loss=0.306]


train Loss: 0.3057 Acc: 0.9154


Val Epoch 68: 100%|██████████| 19/19 [00:02<00:00,  8.40it/s, acc=0.842, loss=0.584]


val Loss: 0.5919 Acc: 0.8533
Epoch 69/100
----------


Train Epoch 69: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.915, loss=0.313]


train Loss: 0.3127 Acc: 0.9155


Val Epoch 69: 100%|██████████| 19/19 [00:02<00:00,  8.36it/s, acc=0.842, loss=0.574]


val Loss: 0.5820 Acc: 0.8533
Epoch 70/100
----------


Train Epoch 70: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.917, loss=0.303]


train Loss: 0.3031 Acc: 0.9180


Val Epoch 70: 100%|██████████| 19/19 [00:02<00:00,  8.40it/s, acc=0.842, loss=0.585]


val Loss: 0.5925 Acc: 0.8533
Epoch 71/100
----------


Train Epoch 71: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.92, loss=0.297] 


train Loss: 0.2976 Acc: 0.9207


Val Epoch 71: 100%|██████████| 19/19 [00:02<00:00,  8.42it/s, acc=0.852, loss=0.575]


val Loss: 0.5822 Acc: 0.8633
Epoch 72/100
----------


Train Epoch 72: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.914, loss=0.31] 


train Loss: 0.3101 Acc: 0.9149


Val Epoch 72: 100%|██████████| 19/19 [00:02<00:00,  8.41it/s, acc=0.852, loss=0.6]  


val Loss: 0.6078 Acc: 0.8633
Epoch 73/100
----------


Train Epoch 73: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.915, loss=0.311]


train Loss: 0.3115 Acc: 0.9155


Val Epoch 73: 100%|██████████| 19/19 [00:02<00:00,  8.41it/s, acc=0.842, loss=0.576]


val Loss: 0.5833 Acc: 0.8533
Epoch 74/100
----------


Train Epoch 74: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.915, loss=0.316]


train Loss: 0.3161 Acc: 0.9151


Val Epoch 74: 100%|██████████| 19/19 [00:02<00:00,  8.40it/s, acc=0.859, loss=0.58] 


val Loss: 0.5876 Acc: 0.8700
Epoch 75/100
----------


Train Epoch 75: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.915, loss=0.305]


train Loss: 0.3049 Acc: 0.9160


Val Epoch 75: 100%|██████████| 19/19 [00:02<00:00,  8.39it/s, acc=0.842, loss=0.581]


val Loss: 0.5885 Acc: 0.8533
Epoch 76/100
----------


Train Epoch 76: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.916, loss=0.317]


train Loss: 0.3170 Acc: 0.9164


Val Epoch 76: 100%|██████████| 19/19 [00:02<00:00,  8.41it/s, acc=0.839, loss=0.59] 


val Loss: 0.5982 Acc: 0.8500
Epoch 77/100
----------


Train Epoch 77: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.915, loss=0.316]


train Loss: 0.3165 Acc: 0.9152


Val Epoch 77: 100%|██████████| 19/19 [00:01<00:00,  9.52it/s, acc=0.849, loss=0.577]


val Loss: 0.5843 Acc: 0.8600
Epoch 78/100
----------


Train Epoch 78: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.916, loss=0.3]  


train Loss: 0.3004 Acc: 0.9167


Val Epoch 78: 100%|██████████| 19/19 [00:02<00:00,  8.41it/s, acc=0.849, loss=0.591]


val Loss: 0.5991 Acc: 0.8600
Epoch 79/100
----------


Train Epoch 79: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.914, loss=0.314]


train Loss: 0.3142 Acc: 0.9146


Val Epoch 79: 100%|██████████| 19/19 [00:02<00:00,  8.36it/s, acc=0.845, loss=0.57] 


val Loss: 0.5774 Acc: 0.8567
Epoch 80/100
----------


Train Epoch 80: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.917, loss=0.308]


train Loss: 0.3083 Acc: 0.9175


Val Epoch 80: 100%|██████████| 19/19 [00:02<00:00,  8.39it/s, acc=0.855, loss=0.578]


val Loss: 0.5857 Acc: 0.8667
Epoch 81/100
----------


Train Epoch 81: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.917, loss=0.31] 


train Loss: 0.3106 Acc: 0.9174


Val Epoch 81: 100%|██████████| 19/19 [00:02<00:00,  8.40it/s, acc=0.842, loss=0.57] 


val Loss: 0.5772 Acc: 0.8533
Epoch 82/100
----------


Train Epoch 82: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.915, loss=0.309]


train Loss: 0.3089 Acc: 0.9153


Val Epoch 82: 100%|██████████| 19/19 [00:02<00:00,  8.42it/s, acc=0.852, loss=0.573]


val Loss: 0.5807 Acc: 0.8633
Epoch 83/100
----------


Train Epoch 83: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.914, loss=0.315]


train Loss: 0.3149 Acc: 0.9142


Val Epoch 83: 100%|██████████| 19/19 [00:02<00:00,  8.43it/s, acc=0.849, loss=0.586]


val Loss: 0.5938 Acc: 0.8600
Epoch 84/100
----------


Train Epoch 84: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.918, loss=0.3]  


train Loss: 0.3002 Acc: 0.9184


Val Epoch 84: 100%|██████████| 19/19 [00:02<00:00,  8.36it/s, acc=0.852, loss=0.583]


val Loss: 0.5910 Acc: 0.8633
Epoch 85/100
----------


Train Epoch 85: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.915, loss=0.307]


train Loss: 0.3072 Acc: 0.9158


Val Epoch 85: 100%|██████████| 19/19 [00:02<00:00,  8.37it/s, acc=0.845, loss=0.563]


val Loss: 0.5708 Acc: 0.8567
Epoch 86/100
----------


Train Epoch 86: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.917, loss=0.306]


train Loss: 0.3060 Acc: 0.9171


Val Epoch 86: 100%|██████████| 19/19 [00:02<00:00,  8.34it/s, acc=0.845, loss=0.591]


val Loss: 0.5985 Acc: 0.8567
Epoch 87/100
----------


Train Epoch 87: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.914, loss=0.312]


train Loss: 0.3125 Acc: 0.9142


Val Epoch 87: 100%|██████████| 19/19 [00:02<00:00,  8.41it/s, acc=0.849, loss=0.582]


val Loss: 0.5901 Acc: 0.8600
Epoch 88/100
----------


Train Epoch 88: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.916, loss=0.305]


train Loss: 0.3048 Acc: 0.9166


Val Epoch 88: 100%|██████████| 19/19 [00:02<00:00,  8.44it/s, acc=0.845, loss=0.591]


val Loss: 0.5988 Acc: 0.8567
Epoch 89/100
----------


Train Epoch 89: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.913, loss=0.318]


train Loss: 0.3179 Acc: 0.9133


Val Epoch 89: 100%|██████████| 19/19 [00:02<00:00,  8.40it/s, acc=0.845, loss=0.577]


val Loss: 0.5851 Acc: 0.8567
Epoch 90/100
----------


Train Epoch 90: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.914, loss=0.315]


train Loss: 0.3149 Acc: 0.9142


Val Epoch 90: 100%|██████████| 19/19 [00:02<00:00,  8.36it/s, acc=0.855, loss=0.582]


val Loss: 0.5893 Acc: 0.8667
Epoch 91/100
----------


Train Epoch 91: 100%|██████████| 1296/1296 [07:14<00:00,  2.98it/s, acc=0.916, loss=0.307]


train Loss: 0.3076 Acc: 0.9165


Val Epoch 91: 100%|██████████| 19/19 [00:02<00:00,  8.42it/s, acc=0.845, loss=0.556]


val Loss: 0.5637 Acc: 0.8567
Epoch 92/100
----------


Train Epoch 92: 100%|██████████| 1296/1296 [07:13<00:00,  2.99it/s, acc=0.916, loss=0.305]


train Loss: 0.3052 Acc: 0.9163


Val Epoch 92: 100%|██████████| 19/19 [00:02<00:00,  9.12it/s, acc=0.859, loss=0.582]


val Loss: 0.5901 Acc: 0.8700
Epoch 93/100
----------


Train Epoch 93: 100%|██████████| 1296/1296 [07:14<00:00,  2.99it/s, acc=0.914, loss=0.316]


train Loss: 0.3165 Acc: 0.9144


Val Epoch 93: 100%|██████████| 19/19 [00:02<00:00,  9.17it/s, acc=0.899, loss=0.613]


val Loss: 0.5885 Acc: 0.8633
Epoch 94/100
----------


Train Epoch 94: 100%|██████████| 1296/1296 [06:47<00:00,  3.18it/s, acc=0.916, loss=0.308]


train Loss: 0.3079 Acc: 0.9163


Val Epoch 94: 100%|██████████| 19/19 [00:02<00:00,  8.94it/s, acc=0.845, loss=0.553]


val Loss: 0.5599 Acc: 0.8567
Epoch 95/100
----------


Train Epoch 95: 100%|██████████| 1296/1296 [06:46<00:00,  3.19it/s, acc=0.917, loss=0.303]


train Loss: 0.3036 Acc: 0.9173


Val Epoch 95: 100%|██████████| 19/19 [00:02<00:00,  9.10it/s, acc=0.889, loss=0.628]


val Loss: 0.6031 Acc: 0.8533
Epoch 96/100
----------


Train Epoch 96: 100%|██████████| 1296/1296 [06:27<00:00,  3.35it/s, acc=0.915, loss=0.311]


train Loss: 0.3112 Acc: 0.9152


Val Epoch 96: 100%|██████████| 19/19 [00:01<00:00,  9.68it/s, acc=0.865, loss=0.795]


val Loss: 0.7631 Acc: 0.8300
Epoch 97/100
----------


Train Epoch 97: 100%|██████████| 1296/1296 [06:26<00:00,  3.35it/s, acc=0.915, loss=0.31] 


train Loss: 0.3101 Acc: 0.9158


Val Epoch 97: 100%|██████████| 19/19 [00:02<00:00,  9.05it/s, acc=0.845, loss=0.58] 


val Loss: 0.5878 Acc: 0.8567
Epoch 98/100
----------


Train Epoch 98: 100%|██████████| 1296/1296 [06:47<00:00,  3.18it/s, acc=0.915, loss=0.315]


train Loss: 0.3149 Acc: 0.9151


Val Epoch 98: 100%|██████████| 19/19 [00:02<00:00,  9.01it/s, acc=0.899, loss=0.593]


val Loss: 0.5696 Acc: 0.8633
Epoch 99/100
----------


Train Epoch 99: 100%|██████████| 1296/1296 [06:47<00:00,  3.18it/s, acc=0.919, loss=0.302]


train Loss: 0.3020 Acc: 0.9194


Val Epoch 99: 100%|██████████| 19/19 [00:02<00:00,  9.11it/s, acc=0.899, loss=0.611]


val Loss: 0.5865 Acc: 0.8633
Epoch 100/100
----------


Train Epoch 100: 100%|██████████| 1296/1296 [06:50<00:00,  3.16it/s, acc=0.913, loss=0.315]


train Loss: 0.3155 Acc: 0.9138


Val Epoch 100: 100%|██████████| 19/19 [00:02<00:00,  8.39it/s, acc=0.839, loss=0.582]


val Loss: 0.5893 Acc: 0.8500
Training complete in 721m 12s
Best val Acc: 0.8733
Model saved as resnet152_cbam_best100.pth


In [7]:
# Load the trained model with CBAM
model = ResNet152_CBAM(num_classes=len(class_names))
model.load_state_dict(torch.load('resnet152_cbam_best100.pth'))
model = model.to(device)
model.eval()

# Define the transform (same as validation)
transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

val_folder = '/home/gpl/文件/Kezia/Visual Recognition/HW1/hw1-data/data/val/'
val_dataset = datasets.ImageFolder(val_folder, transform=transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1,
                                         shuffle=False)

correct = 0
total = 0
all_predictions = []

with torch.no_grad():
    for i, (inputs, labels) in enumerate(tqdm(val_loader)):
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        true_label = labels.item()
        predicted_label = preds.item()

        correct += (predicted_label == true_label)
        total += 1

        # Get the file path of the current sample
        img_path, _ = val_dataset.samples[i]
        all_predictions.append((img_path.split('.')[0], predicted_label))

# Save predictions to a CSV
fname = "val_predictions_resnet152_w_CBAM.csv"
with open(fname, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['image_name', 'pred_label'])  # header
    writer.writerows(all_predictions)


accuracy = correct / total * 100
print(f'Validation Accuracy: {accuracy:.2f}%')

100%|██████████| 300/300 [00:23<00:00, 12.77it/s]

Validation Accuracy: 87.33%





In [8]:
# Load the Model
model = ResNet152_CBAM(num_classes=len(class_names))
model_path = 'resnet152_cbam_best100.pth'
model.load_state_dict(torch.load(model_path))
model = model.to(device)
model.eval()

# Transforms (same as val transforms)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Dataset and Dataloader for test set
test_folder = '/home/gpl/文件/Kezia/Visual Recognition/HW1/hw1-data/data/test'

class TestDataset(Dataset):
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.image_files = [f for f in os.listdir(folder_path)
                            if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.folder_path, img_name)
        image = Image.open(img_path)

        if self.transform:
            image = self.transform(image)

        return image, img_name

# Create the dataset and loader
test_dataset = TestDataset(test_folder, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1,
                                          shuffle=False)

# Class names (same order as training classes)
class_names = image_datasets['train'].classes

# Run Inference on the Test Set
predictions = []

with torch.no_grad():
    for inputs, img_names in tqdm(test_loader, desc="Testing"):
        inputs = inputs.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        predicted_label = preds.item()
        predicted_class = class_names[predicted_label]

        # Save predictions
        predictions.append((os.path.basename(img_names[0]).split('.')[0],
                            predicted_class))

predictions.sort(key=lambda x: x[0])

# Save predictions to a CSV
with open('prediction_resnet152_w_CBAM.csv', mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerow(['image_name', 'pred_label'])  # header
    writer.writerows(predictions)

print("Predictions for test set saved to prediction_resnet152_w_CBAM.csv")


Testing: 100%|██████████| 2344/2344 [03:19<00:00, 11.77it/s]

Predictions for test set saved to prediction_resnet152_w_CBAM.csv





In [9]:
# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the Model
model_path = 'resnet152_cbam_best100.pth'
model.load_state_dict(torch.load(model_path))

# Total number of parameters
total_params = sum(p.numel() for p in model.parameters())

# Number of trainable parameters
trainable_params = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)

print(f'Total parameters: {total_params}')
print(f'Trainable parameters: {trainable_params}')

Total parameters: 59045420
Trainable parameters: 59045420
