# import, loading data

In [1]:
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import torch.nn.functional as F
import random, time, torchprofile
import torchvision.models as models
from tqdm import tqdm
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score

In [2]:
# 定義 CustomImageNetDataset

class CustomImageNetDataset(Dataset):
    """用於 ImageNet 圖片及其註解的數據集類別。"""
    
    def __init__(self, annotation_file, image_directory, transforms=None):
        """
        初始化數據集。
        
        參數:
            annotation_file (str): 註解文件的路徑。
            image_directory (str): 圖片目錄的路徑。
            transforms (callable, optional): 要應用於圖片的轉換。
        """
        self.image_directory = image_directory
        self.transforms = transforms
        self.annotations = self._load_annotations(annotation_file)

    def _load_annotations(self, file_path):
        """從文件加載註解。"""
        with open(file_path, 'r') as f:
            annotations = f.read().splitlines()
        return annotations

    def __len__(self):
        """返回數據集中的圖片總數。"""
        return len(self.annotations)

    def __getitem__(self, index):
        """根據索引檢索圖片及其標籤。"""
        annotation = self.annotations[index]
        img_file, label = annotation.split()
        full_img_path = os.path.join(self.image_directory, img_file)
        image = Image.open(full_img_path).convert('RGB')
        
        if self.transforms:
            image = self.transforms(image)

        return image, int(label)

In [3]:
#Data augmentation

class SelectiveChannelDrop:
    def __init__(self):
        self.possible_channels = [
            (0, 1, 2),  # RGB
            (0, 1),     # RG
            (0, 2),     # RB
            (1, 2),     # GB
            (0,),       # R
            (1,),       # G
            (2,)        # B
        ]

    def __call__(self, image):
        channels = list(image.split())
        selected_channels = random.choice(self.possible_channels)
        new_channels = [channels[i] if i in selected_channels else Image.new('L', image.size, 'black') for i in range(3)]
        output_image = Image.merge('RGB', new_channels)
        return output_image

# 配置圖像變換

basic_transforms = transforms.Compose([
    transforms.Resize(96),
    transforms.CenterCrop(96),
    transforms.ToTensor(),
])

training_transforms = transforms.Compose([
    transforms.RandomResizedCrop(96),
    transforms.RandomHorizontalFlip(1),
    transforms.RandomVerticalFlip(0.1),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.2),  
    transforms.RandomRotation(15),
    SelectiveChannelDrop(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


validation_transforms = transforms.Compose([
    transforms.Resize(96),
    transforms.CenterCrop(96),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [5]:
# 把 train, validation data load進來

train_dataset = CustomImageNetDataset(annotation_file='data/images/train.txt', image_directory='data/images', transforms=training_transforms)
val_dataset = CustomImageNetDataset(annotation_file='data/images/val.txt', image_directory='data/images', transforms=validation_transforms)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# 我的設計模型

In [4]:
#設計 DynamicConv2D 

class DynamicConv2D(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, stride=1, padding=0, use_bias=True):
        super(DynamicConv2D, self).__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.use_bias = use_bias

        # 初始化可調整的卷積核權重及偏置
        self.weights = nn.Parameter(torch.randn(output_channels, input_channels, kernel_size, kernel_size))
        if use_bias:
            self.biases = nn.Parameter(torch.randn(output_channels))
        else:
            self.register_parameter('biases', None)

    def forward(self, input_tensor):
        # 根據輸入的通道數量調整權重
        current_input_channels = input_tensor.size(1)
        if current_input_channels < self.input_channels:
            adjusted_weights = self.weights[:, :current_input_channels, :, :]
        else:
            adjusted_weights = self.weights

        # 執行卷積操作
        return F.conv2d(input_tensor, adjusted_weights, self.biases, self.stride, self.padding)


In [6]:
#把resnet18第一層改成 DynamicConv2D 

model = models.resnet18(weights=None)
model.conv1 = DynamicConv2D(input_channels=3, output_channels=64, kernel_size=7, stride=2, padding=3, use_bias=False)
model.fc = torch.nn.Linear(model.fc.in_features, 50)
model.cuda()

ResNet(
  (conv1): DynamicConv2D()
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)

In [7]:
#定義train, validation

lr = 0.1
epochs = 30
momentum = 0.9
weight_decay = 1e-4
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)


def adjust_lr(optimizer, epoch):
    if epoch in [epochs*0.5, epochs*0.75, epochs*0.85]:
        for p in optimizer.param_groups:
            p['lr'] *= 0.1
            lr = p['lr']
        print('Change lr:'+str(lr))

def train(epoch):
    model.train()
    avg_loss = 0.
    train_acc = 0.
    adjust_lr(optimizer, epoch)
    train_loader_len = len(train_loader.dataset)
    train_loader_iter = tqdm(enumerate(train_loader), total=len(train_loader), desc="Training Epoch #{}".format(epoch))
    
    for batch_idx, (data, target) in train_loader_iter:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        avg_loss += loss.item()
        pred = output.data.max(1, keepdim=True)[1]
        train_acc += pred.eq(target.data.view_as(pred)).cpu().sum()
        loss.backward()
        optimizer.step()
        train_loader_iter.set_postfix(loss=loss.item(), accuracy=100. * train_acc.item() / train_loader_len)
    
    print('Train Epoch: {}, Loss: {:.6f}, Accuracy: {:.2f}%'.format(epoch, avg_loss / len(train_loader), 100. * train_acc / train_loader_len))



def val(epoch):
    model.eval()
    test_loss = 0.
    correct = 0
    val_loader_iter = tqdm(val_loader, total=len(val_loader), desc="Validation Epoch #{}".format(epoch))
    
    with torch.no_grad():
        for data, label in val_loader_iter:
            data, label = data.cuda(), label.cuda()
            output = model(data)
            test_loss += F.cross_entropy(output, label, reduction='sum').item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(label.data.view_as(pred)).cpu().sum()
            val_loader_iter.set_postfix(loss=test_loss / len(val_loader.dataset), accuracy=100. * correct.item() / len(val_loader.dataset))
    
    test_loss /= len(val_loader.dataset)
    accuracy = 100. * correct / len(val_loader.dataset)
    print('Validation Set: Average Loss: {:.4f}, Accuracy: {:.2f}%'.format(test_loss, accuracy))

    return accuracy


#開始訓練

best_val_acc = 0.
for i in range(epochs):
    train(i + 1)
    temp_acc = val(i + 1)
    if temp_acc > best_val_acc:
        best_val_acc = temp_acc
        torch.save(model.state_dict(), 'mydesign_best1.pt')
        print('Best Accuracy: {:.2f}%'.format(best_val_acc))

print('Final Best Accuracy: {:.2f}%'.format(best_val_acc))

Training Epoch #1: 100%|██████████| 990/990 [02:58<00:00,  5.55it/s, accuracy=4.65, loss=3.78] 


Train Epoch: 1, Loss: 3.815005, Accuracy: 4.65%


Validation Epoch #1: 100%|██████████| 8/8 [00:01<00:00,  7.68it/s, accuracy=7.56, loss=3.63]  


Validation Set: Average Loss: 3.6338, Accuracy: 7.56%
Best Accuracy: 7.56%


Training Epoch #2: 100%|██████████| 990/990 [02:54<00:00,  5.68it/s, accuracy=8.56, loss=3.72] 


Train Epoch: 2, Loss: 3.574237, Accuracy: 8.56%


Validation Epoch #2: 100%|██████████| 8/8 [00:00<00:00,  8.17it/s, accuracy=15.1, loss=3.22] 


Validation Set: Average Loss: 3.2164, Accuracy: 15.11%
Best Accuracy: 15.11%


Training Epoch #3: 100%|██████████| 990/990 [02:54<00:00,  5.67it/s, accuracy=11, loss=3.4]   


Train Epoch: 3, Loss: 3.447280, Accuracy: 11.01%


Validation Epoch #3: 100%|██████████| 8/8 [00:00<00:00,  8.13it/s, accuracy=17.3, loss=3.15] 


Validation Set: Average Loss: 3.1540, Accuracy: 17.33%
Best Accuracy: 17.33%


Training Epoch #4: 100%|██████████| 990/990 [02:55<00:00,  5.65it/s, accuracy=12.7, loss=3.43] 


Train Epoch: 4, Loss: 3.358057, Accuracy: 12.66%


Validation Epoch #4: 100%|██████████| 8/8 [00:00<00:00,  8.25it/s, accuracy=15.6, loss=3.19]  


Validation Set: Average Loss: 3.1935, Accuracy: 15.56%


Training Epoch #5: 100%|██████████| 990/990 [02:54<00:00,  5.68it/s, accuracy=14.6, loss=3.46]


Train Epoch: 5, Loss: 3.259555, Accuracy: 14.56%


Validation Epoch #5: 100%|██████████| 8/8 [00:00<00:00,  8.26it/s, accuracy=19.1, loss=2.94] 


Validation Set: Average Loss: 2.9393, Accuracy: 19.11%
Best Accuracy: 19.11%


Training Epoch #6: 100%|██████████| 990/990 [02:56<00:00,  5.60it/s, accuracy=16.7, loss=3.49]


Train Epoch: 6, Loss: 3.163172, Accuracy: 16.72%


Validation Epoch #6: 100%|██████████| 8/8 [00:00<00:00,  8.20it/s, accuracy=20.4, loss=2.93] 


Validation Set: Average Loss: 2.9253, Accuracy: 20.44%
Best Accuracy: 20.44%


Training Epoch #7: 100%|██████████| 990/990 [02:56<00:00,  5.60it/s, accuracy=18.8, loss=3.04]


Train Epoch: 7, Loss: 3.056617, Accuracy: 18.80%


Validation Epoch #7: 100%|██████████| 8/8 [00:00<00:00,  8.29it/s, accuracy=24.2, loss=2.77] 


Validation Set: Average Loss: 2.7748, Accuracy: 24.22%
Best Accuracy: 24.22%


Training Epoch #8: 100%|██████████| 990/990 [02:56<00:00,  5.61it/s, accuracy=20.5, loss=2.89]


Train Epoch: 8, Loss: 2.973447, Accuracy: 20.49%


Validation Epoch #8: 100%|██████████| 8/8 [00:00<00:00,  8.20it/s, accuracy=24.9, loss=2.69] 


Validation Set: Average Loss: 2.6882, Accuracy: 24.89%
Best Accuracy: 24.89%


Training Epoch #9: 100%|██████████| 990/990 [02:55<00:00,  5.65it/s, accuracy=22.3, loss=2.91]


Train Epoch: 9, Loss: 2.888124, Accuracy: 22.31%


Validation Epoch #9: 100%|██████████| 8/8 [00:00<00:00,  8.13it/s, accuracy=33.8, loss=2.46] 


Validation Set: Average Loss: 2.4637, Accuracy: 33.78%
Best Accuracy: 33.78%


Training Epoch #10: 100%|██████████| 990/990 [02:55<00:00,  5.65it/s, accuracy=23.8, loss=2.84]


Train Epoch: 10, Loss: 2.816193, Accuracy: 23.84%


Validation Epoch #10: 100%|██████████| 8/8 [00:00<00:00,  8.21it/s, accuracy=29.6, loss=2.45] 


Validation Set: Average Loss: 2.4519, Accuracy: 29.56%


Training Epoch #11: 100%|██████████| 990/990 [02:55<00:00,  5.65it/s, accuracy=25.4, loss=2.77]


Train Epoch: 11, Loss: 2.749706, Accuracy: 25.36%


Validation Epoch #11: 100%|██████████| 8/8 [00:00<00:00,  8.07it/s, accuracy=33.3, loss=2.35] 


Validation Set: Average Loss: 2.3485, Accuracy: 33.33%


Training Epoch #12: 100%|██████████| 990/990 [02:55<00:00,  5.65it/s, accuracy=26.7, loss=2.3] 


Train Epoch: 12, Loss: 2.699769, Accuracy: 26.66%


Validation Epoch #12: 100%|██████████| 8/8 [00:00<00:00,  8.21it/s, accuracy=34.9, loss=2.26] 


Validation Set: Average Loss: 2.2553, Accuracy: 34.89%
Best Accuracy: 34.89%


Training Epoch #13: 100%|██████████| 990/990 [02:55<00:00,  5.65it/s, accuracy=27.7, loss=2.59]


Train Epoch: 13, Loss: 2.644907, Accuracy: 27.68%


Validation Epoch #13: 100%|██████████| 8/8 [00:00<00:00,  8.11it/s, accuracy=38.9, loss=2.24] 


Validation Set: Average Loss: 2.2382, Accuracy: 38.89%
Best Accuracy: 38.89%


Training Epoch #14: 100%|██████████| 990/990 [02:58<00:00,  5.55it/s, accuracy=29, loss=2.7]   


Train Epoch: 14, Loss: 2.589126, Accuracy: 29.03%


Validation Epoch #14: 100%|██████████| 8/8 [00:00<00:00,  8.26it/s, accuracy=36.9, loss=2.23] 


Validation Set: Average Loss: 2.2306, Accuracy: 36.89%
Change lr:0.010000000000000002


Training Epoch #15: 100%|██████████| 990/990 [02:55<00:00,  5.65it/s, accuracy=35.6, loss=2.01]


Train Epoch: 15, Loss: 2.307866, Accuracy: 35.57%


Validation Epoch #15: 100%|██████████| 8/8 [00:00<00:00,  8.11it/s, accuracy=49.3, loss=1.79] 


Validation Set: Average Loss: 1.7872, Accuracy: 49.33%
Best Accuracy: 49.33%


Training Epoch #16: 100%|██████████| 990/990 [02:55<00:00,  5.65it/s, accuracy=38, loss=2.02]  


Train Epoch: 16, Loss: 2.206610, Accuracy: 37.98%


Validation Epoch #16: 100%|██████████| 8/8 [00:00<00:00,  8.27it/s, accuracy=49.8, loss=1.74] 


Validation Set: Average Loss: 1.7439, Accuracy: 49.78%
Best Accuracy: 49.78%


Training Epoch #17: 100%|██████████| 990/990 [02:57<00:00,  5.58it/s, accuracy=39.2, loss=2.27]


Train Epoch: 17, Loss: 2.162020, Accuracy: 39.25%


Validation Epoch #17: 100%|██████████| 8/8 [00:01<00:00,  7.79it/s, accuracy=49.8, loss=1.71] 


Validation Set: Average Loss: 1.7071, Accuracy: 49.78%


Training Epoch #18: 100%|██████████| 990/990 [02:58<00:00,  5.54it/s, accuracy=40, loss=2.16]  


Train Epoch: 18, Loss: 2.134986, Accuracy: 40.01%


Validation Epoch #18: 100%|██████████| 8/8 [00:00<00:00,  8.08it/s, accuracy=52, loss=1.66]   


Validation Set: Average Loss: 1.6612, Accuracy: 52.00%
Best Accuracy: 52.00%


Training Epoch #19: 100%|██████████| 990/990 [02:55<00:00,  5.64it/s, accuracy=40.5, loss=2.33]


Train Epoch: 19, Loss: 2.108404, Accuracy: 40.49%


Validation Epoch #19: 100%|██████████| 8/8 [00:00<00:00,  8.16it/s, accuracy=52.7, loss=1.64] 


Validation Set: Average Loss: 1.6375, Accuracy: 52.67%
Best Accuracy: 52.67%


Training Epoch #20: 100%|██████████| 990/990 [02:55<00:00,  5.64it/s, accuracy=41.3, loss=1.66]


Train Epoch: 20, Loss: 2.080119, Accuracy: 41.26%


Validation Epoch #20: 100%|██████████| 8/8 [00:00<00:00,  8.01it/s, accuracy=53.1, loss=1.63] 


Validation Set: Average Loss: 1.6263, Accuracy: 53.11%
Best Accuracy: 53.11%


Training Epoch #21: 100%|██████████| 990/990 [02:54<00:00,  5.66it/s, accuracy=41.6, loss=1.43]


Train Epoch: 21, Loss: 2.055221, Accuracy: 41.63%


Validation Epoch #21: 100%|██████████| 8/8 [00:01<00:00,  7.97it/s, accuracy=54.7, loss=1.6]  


Validation Set: Average Loss: 1.5996, Accuracy: 54.67%
Best Accuracy: 54.67%


Training Epoch #22: 100%|██████████| 990/990 [02:54<00:00,  5.67it/s, accuracy=42.2, loss=2.47]


Train Epoch: 22, Loss: 2.042154, Accuracy: 42.23%


Validation Epoch #22: 100%|██████████| 8/8 [00:00<00:00,  8.00it/s, accuracy=54, loss=1.57]   


Validation Set: Average Loss: 1.5681, Accuracy: 54.00%


Training Epoch #23: 100%|██████████| 990/990 [02:58<00:00,  5.54it/s, accuracy=42.9, loss=1.95]


Train Epoch: 23, Loss: 2.020481, Accuracy: 42.86%


Validation Epoch #23: 100%|██████████| 8/8 [00:00<00:00,  8.18it/s, accuracy=52.7, loss=1.55] 


Validation Set: Average Loss: 1.5497, Accuracy: 52.67%


Training Epoch #24: 100%|██████████| 990/990 [02:55<00:00,  5.65it/s, accuracy=43.5, loss=1.74]


Train Epoch: 24, Loss: 2.000288, Accuracy: 43.47%


Validation Epoch #24: 100%|██████████| 8/8 [00:00<00:00,  8.22it/s, accuracy=55.3, loss=1.51] 


Validation Set: Average Loss: 1.5115, Accuracy: 55.33%
Best Accuracy: 55.33%


Training Epoch #25: 100%|██████████| 990/990 [03:01<00:00,  5.45it/s, accuracy=43.8, loss=2.18]


Train Epoch: 25, Loss: 1.986293, Accuracy: 43.82%


Validation Epoch #25: 100%|██████████| 8/8 [00:01<00:00,  7.63it/s, accuracy=55.6, loss=1.53] 


Validation Set: Average Loss: 1.5308, Accuracy: 55.56%
Best Accuracy: 55.56%


Training Epoch #26: 100%|██████████| 990/990 [02:59<00:00,  5.50it/s, accuracy=44.1, loss=2.03]


Train Epoch: 26, Loss: 1.964272, Accuracy: 44.10%


Validation Epoch #26: 100%|██████████| 8/8 [00:00<00:00,  8.26it/s, accuracy=55.1, loss=1.51] 


Validation Set: Average Loss: 1.5084, Accuracy: 55.11%


Training Epoch #27: 100%|██████████| 990/990 [02:54<00:00,  5.68it/s, accuracy=44.5, loss=1.89]


Train Epoch: 27, Loss: 1.948063, Accuracy: 44.54%


Validation Epoch #27: 100%|██████████| 8/8 [00:00<00:00,  8.05it/s, accuracy=54.7, loss=1.49] 


Validation Set: Average Loss: 1.4940, Accuracy: 54.67%


Training Epoch #28: 100%|██████████| 990/990 [02:58<00:00,  5.56it/s, accuracy=44.9, loss=2.34]


Train Epoch: 28, Loss: 1.933981, Accuracy: 44.85%


Validation Epoch #28: 100%|██████████| 8/8 [00:00<00:00,  8.12it/s, accuracy=53.6, loss=1.49] 


Validation Set: Average Loss: 1.4948, Accuracy: 53.56%


Training Epoch #29: 100%|██████████| 990/990 [02:55<00:00,  5.65it/s, accuracy=45.7, loss=2.08]


Train Epoch: 29, Loss: 1.909190, Accuracy: 45.69%


Validation Epoch #29: 100%|██████████| 8/8 [00:00<00:00,  8.02it/s, accuracy=53.3, loss=1.55] 


Validation Set: Average Loss: 1.5511, Accuracy: 53.33%


Training Epoch #30: 100%|██████████| 990/990 [02:55<00:00,  5.63it/s, accuracy=45.9, loss=1.97]


Train Epoch: 30, Loss: 1.899953, Accuracy: 45.89%


Validation Epoch #30: 100%|██████████| 8/8 [00:01<00:00,  7.95it/s, accuracy=57.3, loss=1.45] 


Validation Set: Average Loss: 1.4464, Accuracy: 57.33%
Best Accuracy: 57.33%
Final Best Accuracy: 57.33%


In [8]:
#把 best model load 進來

model = models.resnet18(weights=None)
model.conv1 = DynamicConv2D(input_channels=3, output_channels=64, kernel_size=7, stride=2, padding=3, use_bias=False)
model.fc = torch.nn.Linear(model.fc.in_features, 50)
model.load_state_dict(torch.load("mydesign_best1.pt"))
model.cuda()

ResNet(
  (conv1): DynamicConv2D()
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)

In [19]:
# 把 test data load進來 ＆ data transforms 



def select_channels(image, channels='RGB'):
    if channels == 'R':
        return image[0, :, :].unsqueeze(0)
    elif channels == 'G':
        return image[1, :, :].unsqueeze(0)
    elif channels == 'B':
        return image[2, :, :].unsqueeze(0)
    elif channels == 'RG':
        return image[0:2, :, :]
    elif channels == 'RB':
        return torch.stack([image[0, :, :], image[2, :, :]], dim=0)
    elif channels == 'GB':
        return image[1:, :, :]
    elif channels == 'RGB':
        return image
    else:
        raise ValueError('Invalid channel selection')


def rgb_dataloader(set_channel = 'RGB'):
    test_transforms = transforms.Compose([
        transforms.Resize(96),
        transforms.CenterCrop(96),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        transforms.Lambda(lambda x: select_channels(x, set_channel))  # Change 'RGB' to other combinations as needed
    ])
    
    test_dataset = CustomImageNetDataset(annotation_file='data/images/test.txt', image_directory='data/images', transforms=test_transforms)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    return test_loader

In [20]:
#定義 test

def test(model, rgb_set):
    model.eval()
    correct = 0
    test_loader = rgb_dataloader(set_channel = rgb_set)
    test_loader_len = len(test_loader.dataset)
    test_loader_iter = tqdm(test_loader, total=len(test_loader), desc="Testing")
    
    all_preds = []
    all_targets = []
    
    start_time = time.time()
    
    with torch.no_grad():
        for data, target in test_loader_iter:
            data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
            
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
            
            test_loader_iter.set_postfix(accuracy=100. * correct.item() / test_loader_len)
    
    end_time = time.time()
    elapsed_time = end_time - start_time
    
    accuracy = 100. * correct / test_loader_len
    
    # 計算 Precision, Recall 和 F1-score
    precision = 100. * precision_score(all_targets, all_preds, average='macro')
    recall = 100. * recall_score(all_targets, all_preds, average='macro')
    f1 = 100. * f1_score(all_targets, all_preds, average='macro')
    
    # 計算 FLOPS
    flops = torchprofile.profile_macs(model, torch.randn(1, *data.shape[1:]).cuda())
    
    return accuracy.item(), precision, recall, f1, flops, elapsed_time

#開始測試

rgb_list = ["RGB", "RG", "RB", "GB", "R", "G", "B"]
for rgb in rgb_list:
    resnet_acc, resnet_precision, resnet_recall, resnet_f1, resnet_flops, resnet_elapsed_time = test(model,rgb)
    print(f"RGB Set: {rgb:s}, Accuracy: {resnet_acc:.2f}%, Precision: {resnet_precision:.2f}%, Recall: {resnet_recall:.2f}%, F1 Score: {resnet_f1:.2f}%, FLOPS: {resnet_flops:d}, Elapsed Time: {resnet_elapsed_time:.2f} seconds")
    # resnet18_acc, resnet18_elapsed_time = test(model,rgb)
    # print(f"RGB Set: {rgb:s}, Accuracy: {resnet18_acc:.2f}%, Elapsed Time: {resnet18_elapsed_time:.2f} seconds")

Testing: 100%|██████████| 8/8 [00:01<00:00,  7.41it/s, accuracy=57.1]


RGB Set: RGB, Accuracy: 57.11%, Precision: 58.52%, Recall: 57.11%, F1 Score: 56.00%, FLOPS: 333585408, Elapsed Time: 1.08 seconds


Testing: 100%|██████████| 8/8 [00:01<00:00,  7.61it/s, accuracy=44]  


RGB Set: RG, Accuracy: 44.00%, Precision: 51.40%, Recall: 44.00%, F1 Score: 42.50%, FLOPS: 326360064, Elapsed Time: 1.05 seconds


Testing: 100%|██████████| 8/8 [00:01<00:00,  7.56it/s, accuracy=40.2]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


RGB Set: RB, Accuracy: 40.22%, Precision: 46.48%, Recall: 40.22%, F1 Score: 38.41%, FLOPS: 326360064, Elapsed Time: 1.06 seconds


Testing: 100%|██████████| 8/8 [00:01<00:00,  7.50it/s, accuracy=37.8]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


RGB Set: GB, Accuracy: 37.78%, Precision: 44.57%, Recall: 37.78%, F1 Score: 35.56%, FLOPS: 326360064, Elapsed Time: 1.07 seconds


Testing: 100%|██████████| 8/8 [00:01<00:00,  7.87it/s, accuracy=17.6]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


RGB Set: R, Accuracy: 17.56%, Precision: 22.61%, Recall: 17.56%, F1 Score: 15.33%, FLOPS: 319134720, Elapsed Time: 1.02 seconds


Testing: 100%|██████████| 8/8 [00:01<00:00,  8.00it/s, accuracy=15.6]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


RGB Set: G, Accuracy: 15.56%, Precision: 16.53%, Recall: 15.56%, F1 Score: 12.30%, FLOPS: 319134720, Elapsed Time: 1.00 seconds


Testing: 100%|██████████| 8/8 [00:00<00:00,  8.03it/s, accuracy=12.7]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


RGB Set: B, Accuracy: 12.67%, Precision: 14.57%, Recall: 12.67%, F1 Score: 10.32%, FLOPS: 319134720, Elapsed Time: 1.00 seconds


# Resnet18 模型

In [27]:
#引入 resnet18

model = models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, 50)
model.cuda()



ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [30]:
#定義train, validation

epochs = 30
# batch_size = 32
# test_batch_size = 32
lr = 0.1
momentum = 0.9
weight_decay = 1e-4
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

def adjust_lr(optimizer, epoch):
    if epoch in [epochs*0.5, epochs*0.75, epochs*0.85]:
        for p in optimizer.param_groups:
            p['lr'] *= 0.1
            lr = p['lr']
        print('Change lr:'+str(lr))

def train(epoch):
    model.train()
    avg_loss = 0.
    train_acc = 0.
    adjust_lr(optimizer, epoch)
    train_loader_len = len(train_loader.dataset)
    train_loader_iter = tqdm(enumerate(train_loader), total=len(train_loader), desc="Training Epoch #{}".format(epoch))
    
    for batch_idx, (data, target) in train_loader_iter:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        avg_loss += loss.item()
        pred = output.data.max(1, keepdim=True)[1]
        train_acc += pred.eq(target.data.view_as(pred)).cpu().sum()
        loss.backward()
        optimizer.step()
        train_loader_iter.set_postfix(loss=loss.item(), accuracy=100. * train_acc.item() / train_loader_len)
    
    print('Train Epoch: {}, Loss: {:.6f}, Accuracy: {:.2f}%'.format(epoch, avg_loss / len(train_loader), 100. * train_acc / train_loader_len))



def val(epoch):
    model.eval()
    test_loss = 0.
    correct = 0
    val_loader_iter = tqdm(val_loader, total=len(val_loader), desc="Validation Epoch #{}".format(epoch))
    
    with torch.no_grad():
        for data, label in val_loader_iter:
            data, label = data.cuda(), label.cuda()
            output = model(data)
            test_loss += F.cross_entropy(output, label, reduction='sum').item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(label.data.view_as(pred)).cpu().sum()
            val_loader_iter.set_postfix(loss=test_loss / len(val_loader.dataset), accuracy=100. * correct.item() / len(val_loader.dataset))
    
    test_loss /= len(val_loader.dataset)
    accuracy = 100. * correct / len(val_loader.dataset)
    print('Validation Set: Average Loss: {:.4f}, Accuracy: {:.2f}%'.format(test_loss, accuracy))

    return accuracy


best_val_acc = 0.
for i in range(epochs):
    train(i + 1)
    temp_acc = val(i + 1)
    if temp_acc > best_val_acc:
        best_val_acc = temp_acc
        torch.save(model.state_dict(), 'resnet18_best.pt')
        print('Best Accuracy: {:.2f}%'.format(best_val_acc))

print('Final Best Accuracy: {:.2f}%'.format(best_val_acc))
    

Training Epoch #1: 100%|██████████| 990/990 [02:59<00:00,  5.53it/s, accuracy=3.6, loss=3.88]  


Train Epoch: 1, Loss: 3.877870, Accuracy: 3.60%


Validation Epoch #1: 100%|██████████| 8/8 [00:00<00:00,  8.17it/s, accuracy=8.22, loss=3.67] 


Validation Set: Average Loss: 3.6710, Accuracy: 8.22%
Best Accuracy: 8.22%


Training Epoch #2: 100%|██████████| 990/990 [02:59<00:00,  5.51it/s, accuracy=6.26, loss=3.67] 


Train Epoch: 2, Loss: 3.704179, Accuracy: 6.26%


Validation Epoch #2: 100%|██████████| 8/8 [00:00<00:00,  8.04it/s, accuracy=15.3, loss=3.37] 


Validation Set: Average Loss: 3.3713, Accuracy: 15.33%
Best Accuracy: 15.33%


Training Epoch #3: 100%|██████████| 990/990 [02:57<00:00,  5.58it/s, accuracy=8.92, loss=3.49] 


Train Epoch: 3, Loss: 3.555709, Accuracy: 8.92%


Validation Epoch #3: 100%|██████████| 8/8 [00:01<00:00,  7.95it/s, accuracy=14.7, loss=3.26] 


Validation Set: Average Loss: 3.2584, Accuracy: 14.67%


Training Epoch #4: 100%|██████████| 990/990 [02:56<00:00,  5.61it/s, accuracy=10.8, loss=2.92] 


Train Epoch: 4, Loss: 3.454451, Accuracy: 10.81%


Validation Epoch #4: 100%|██████████| 8/8 [00:00<00:00,  8.12it/s, accuracy=13.8, loss=3.21] 


Validation Set: Average Loss: 3.2095, Accuracy: 13.78%


Training Epoch #5: 100%|██████████| 990/990 [02:57<00:00,  5.59it/s, accuracy=12.3, loss=3.69]


Train Epoch: 5, Loss: 3.379856, Accuracy: 12.32%


Validation Epoch #5: 100%|██████████| 8/8 [00:00<00:00,  8.02it/s, accuracy=21.6, loss=3.05] 


Validation Set: Average Loss: 3.0533, Accuracy: 21.56%
Best Accuracy: 21.56%


Training Epoch #6: 100%|██████████| 990/990 [02:55<00:00,  5.63it/s, accuracy=13.7, loss=3.15]


Train Epoch: 6, Loss: 3.310372, Accuracy: 13.74%


Validation Epoch #6: 100%|██████████| 8/8 [00:00<00:00,  8.11it/s, accuracy=17.8, loss=2.99] 


Validation Set: Average Loss: 2.9895, Accuracy: 17.78%


Training Epoch #7: 100%|██████████| 990/990 [02:57<00:00,  5.57it/s, accuracy=15.4, loss=3.59]


Train Epoch: 7, Loss: 3.234963, Accuracy: 15.42%


Validation Epoch #7: 100%|██████████| 8/8 [00:01<00:00,  7.90it/s, accuracy=23.3, loss=2.84] 


Validation Set: Average Loss: 2.8385, Accuracy: 23.33%
Best Accuracy: 23.33%


Training Epoch #8: 100%|██████████| 990/990 [02:56<00:00,  5.61it/s, accuracy=16.8, loss=3.53]


Train Epoch: 8, Loss: 3.166563, Accuracy: 16.81%


Validation Epoch #8: 100%|██████████| 8/8 [00:00<00:00,  8.04it/s, accuracy=24.7, loss=2.83] 


Validation Set: Average Loss: 2.8339, Accuracy: 24.67%
Best Accuracy: 24.67%


Training Epoch #9: 100%|██████████| 990/990 [02:56<00:00,  5.59it/s, accuracy=18, loss=3.26]  


Train Epoch: 9, Loss: 3.103732, Accuracy: 18.05%


Validation Epoch #9: 100%|██████████| 8/8 [00:00<00:00,  8.16it/s, accuracy=26.2, loss=2.76] 


Validation Set: Average Loss: 2.7557, Accuracy: 26.22%
Best Accuracy: 26.22%


Training Epoch #10: 100%|██████████| 990/990 [02:56<00:00,  5.60it/s, accuracy=19, loss=2.89]  


Train Epoch: 10, Loss: 3.047830, Accuracy: 18.98%


Validation Epoch #10: 100%|██████████| 8/8 [00:01<00:00,  7.94it/s, accuracy=30.2, loss=2.59] 


Validation Set: Average Loss: 2.5937, Accuracy: 30.22%
Best Accuracy: 30.22%


Training Epoch #11: 100%|██████████| 990/990 [02:57<00:00,  5.59it/s, accuracy=20.3, loss=3.18]


Train Epoch: 11, Loss: 2.993371, Accuracy: 20.30%


Validation Epoch #11: 100%|██████████| 8/8 [00:01<00:00,  8.00it/s, accuracy=31.3, loss=2.65] 


Validation Set: Average Loss: 2.6488, Accuracy: 31.33%
Best Accuracy: 31.33%


Training Epoch #12: 100%|██████████| 990/990 [02:57<00:00,  5.58it/s, accuracy=21.3, loss=2.79]


Train Epoch: 12, Loss: 2.943518, Accuracy: 21.29%


Validation Epoch #12: 100%|██████████| 8/8 [00:01<00:00,  7.48it/s, accuracy=29.8, loss=2.56] 


Validation Set: Average Loss: 2.5565, Accuracy: 29.78%


Training Epoch #13: 100%|██████████| 990/990 [02:59<00:00,  5.51it/s, accuracy=22.4, loss=2.79]


Train Epoch: 13, Loss: 2.895362, Accuracy: 22.41%


Validation Epoch #13: 100%|██████████| 8/8 [00:00<00:00,  8.05it/s, accuracy=26, loss=2.66]   


Validation Set: Average Loss: 2.6567, Accuracy: 26.00%


Training Epoch #14: 100%|██████████| 990/990 [03:05<00:00,  5.34it/s, accuracy=23.3, loss=3.27]


Train Epoch: 14, Loss: 2.847849, Accuracy: 23.32%


Validation Epoch #14: 100%|██████████| 8/8 [00:01<00:00,  7.85it/s, accuracy=32, loss=2.4]    


Validation Set: Average Loss: 2.4016, Accuracy: 32.00%
Best Accuracy: 32.00%
Change lr:0.010000000000000002


Training Epoch #15: 100%|██████████| 990/990 [03:02<00:00,  5.44it/s, accuracy=29.3, loss=2.73]


Train Epoch: 15, Loss: 2.590407, Accuracy: 29.28%


Validation Epoch #15: 100%|██████████| 8/8 [00:00<00:00,  8.03it/s, accuracy=42, loss=2.08]   


Validation Set: Average Loss: 2.0810, Accuracy: 42.00%
Best Accuracy: 42.00%


Training Epoch #16: 100%|██████████| 990/990 [02:57<00:00,  5.57it/s, accuracy=31, loss=2.89]  


Train Epoch: 16, Loss: 2.509489, Accuracy: 30.97%


Validation Epoch #16: 100%|██████████| 8/8 [00:01<00:00,  7.99it/s, accuracy=44.7, loss=2.03] 


Validation Set: Average Loss: 2.0338, Accuracy: 44.67%
Best Accuracy: 44.67%


Training Epoch #17: 100%|██████████| 990/990 [02:59<00:00,  5.52it/s, accuracy=32.1, loss=2.69]


Train Epoch: 17, Loss: 2.469996, Accuracy: 32.06%


Validation Epoch #17: 100%|██████████| 8/8 [00:01<00:00,  7.61it/s, accuracy=43.3, loss=2.04] 


Validation Set: Average Loss: 2.0419, Accuracy: 43.33%


Training Epoch #18: 100%|██████████| 990/990 [02:59<00:00,  5.51it/s, accuracy=32.6, loss=2.59]


Train Epoch: 18, Loss: 2.441964, Accuracy: 32.62%


Validation Epoch #18: 100%|██████████| 8/8 [00:00<00:00,  8.14it/s, accuracy=43.3, loss=2]    


Validation Set: Average Loss: 1.9982, Accuracy: 43.33%


Training Epoch #19: 100%|██████████| 990/990 [02:57<00:00,  5.56it/s, accuracy=33.4, loss=2.43]


Train Epoch: 19, Loss: 2.408815, Accuracy: 33.35%


Validation Epoch #19: 100%|██████████| 8/8 [00:00<00:00,  8.04it/s, accuracy=44.9, loss=1.98] 


Validation Set: Average Loss: 1.9786, Accuracy: 44.89%
Best Accuracy: 44.89%


Training Epoch #20: 100%|██████████| 990/990 [02:54<00:00,  5.69it/s, accuracy=33.8, loss=2.04]


Train Epoch: 20, Loss: 2.382054, Accuracy: 33.83%


Validation Epoch #20: 100%|██████████| 8/8 [00:00<00:00,  8.07it/s, accuracy=46.2, loss=1.94] 


Validation Set: Average Loss: 1.9419, Accuracy: 46.22%
Best Accuracy: 46.22%


Training Epoch #21: 100%|██████████| 990/990 [03:04<00:00,  5.36it/s, accuracy=34.5, loss=2.69]


Train Epoch: 21, Loss: 2.360101, Accuracy: 34.51%


Validation Epoch #21: 100%|██████████| 8/8 [00:01<00:00,  7.92it/s, accuracy=46, loss=1.91]   


Validation Set: Average Loss: 1.9068, Accuracy: 46.00%


Training Epoch #22: 100%|██████████| 990/990 [02:59<00:00,  5.51it/s, accuracy=35.1, loss=2.07]


Train Epoch: 22, Loss: 2.334501, Accuracy: 35.15%


Validation Epoch #22: 100%|██████████| 8/8 [00:00<00:00,  8.08it/s, accuracy=44.9, loss=1.93] 


Validation Set: Average Loss: 1.9285, Accuracy: 44.89%


Training Epoch #23: 100%|██████████| 990/990 [03:01<00:00,  5.44it/s, accuracy=35.5, loss=2.22]


Train Epoch: 23, Loss: 2.312828, Accuracy: 35.54%


Validation Epoch #23: 100%|██████████| 8/8 [00:01<00:00,  7.99it/s, accuracy=48.7, loss=1.85] 


Validation Set: Average Loss: 1.8508, Accuracy: 48.67%
Best Accuracy: 48.67%


Training Epoch #24: 100%|██████████| 990/990 [02:58<00:00,  5.54it/s, accuracy=36.4, loss=2.55]


Train Epoch: 24, Loss: 2.289448, Accuracy: 36.36%


Validation Epoch #24: 100%|██████████| 8/8 [00:00<00:00,  8.11it/s, accuracy=48.2, loss=1.9]  


Validation Set: Average Loss: 1.8967, Accuracy: 48.22%


Training Epoch #25: 100%|██████████| 990/990 [02:55<00:00,  5.65it/s, accuracy=36.9, loss=2.39]


Train Epoch: 25, Loss: 2.265477, Accuracy: 36.93%


Validation Epoch #25: 100%|██████████| 8/8 [00:01<00:00,  7.63it/s, accuracy=47.8, loss=1.82] 


Validation Set: Average Loss: 1.8178, Accuracy: 47.78%


Training Epoch #26: 100%|██████████| 990/990 [02:57<00:00,  5.58it/s, accuracy=37.1, loss=2.48]


Train Epoch: 26, Loss: 2.246955, Accuracy: 37.09%


Validation Epoch #26: 100%|██████████| 8/8 [00:00<00:00,  8.16it/s, accuracy=49.3, loss=1.83] 


Validation Set: Average Loss: 1.8269, Accuracy: 49.33%
Best Accuracy: 49.33%


Training Epoch #27: 100%|██████████| 990/990 [02:54<00:00,  5.66it/s, accuracy=37.8, loss=2.36]


Train Epoch: 27, Loss: 2.225855, Accuracy: 37.80%


Validation Epoch #27: 100%|██████████| 8/8 [00:00<00:00,  8.12it/s, accuracy=45.8, loss=1.84] 


Validation Set: Average Loss: 1.8445, Accuracy: 45.78%


Training Epoch #28: 100%|██████████| 990/990 [02:54<00:00,  5.66it/s, accuracy=38.3, loss=2.3] 


Train Epoch: 28, Loss: 2.207562, Accuracy: 38.29%


Validation Epoch #28: 100%|██████████| 8/8 [00:00<00:00,  8.19it/s, accuracy=50.7, loss=1.71] 


Validation Set: Average Loss: 1.7112, Accuracy: 50.67%
Best Accuracy: 50.67%


Training Epoch #29: 100%|██████████| 990/990 [02:55<00:00,  5.63it/s, accuracy=38.9, loss=2.21]


Train Epoch: 29, Loss: 2.186268, Accuracy: 38.94%


Validation Epoch #29: 100%|██████████| 8/8 [00:01<00:00,  7.97it/s, accuracy=48.9, loss=1.71] 


Validation Set: Average Loss: 1.7073, Accuracy: 48.89%


Training Epoch #30: 100%|██████████| 990/990 [02:57<00:00,  5.56it/s, accuracy=39.8, loss=1.96]


Train Epoch: 30, Loss: 2.155499, Accuracy: 39.77%


Validation Epoch #30: 100%|██████████| 8/8 [00:01<00:00,  7.89it/s, accuracy=52.4, loss=1.65] 


Validation Set: Average Loss: 1.6484, Accuracy: 52.44%
Best Accuracy: 52.44%
Final Best Accuracy: 52.44%


In [31]:
#把 best model load 進來

model = models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(model.fc.in_features, 50)
model.load_state_dict(torch.load("resnet18_best.pt"))
model.cuda()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [33]:
# 把 test data load進來 ＆ data transforms 

test_augmentations = transforms.Compose([
    transforms.Resize(96),
    transforms.CenterCrop(96),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

test_dataset = CustomImageNetDataset(annotation_file='data/images/test.txt', image_directory='data/images', transforms=test_augmentations)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [36]:
#定義test

def test(model):
    model.eval()
    correct = 0
    test_loader_len = len(test_loader.dataset)
    test_loader_iter = tqdm(test_loader, total=len(test_loader), desc="Testing")
    
    all_preds = []
    all_targets = []
    
    start_time = time.time()
    
    with torch.no_grad():
        for data, target in test_loader_iter:
            data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).cpu().sum()
            
            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
            
            test_loader_iter.set_postfix(accuracy=100. * correct.item() / test_loader_len)
    
    end_time = time.time()
    elapsed_time = end_time - start_time
    
    accuracy = 100. * correct / test_loader_len
    
    # 計算 Precision, Recall 和 F1-score
    precision = 100. * precision_score(all_targets, all_preds, average='macro')
    recall = 100. * recall_score(all_targets, all_preds, average='macro')
    f1 = 100. * f1_score(all_targets, all_preds, average='macro')
    
    # 計算 FLOPS
    flops = torchprofile.profile_macs(model, torch.randn(1, *data.shape[1:]).cuda())
    
    return accuracy.item(), precision, recall, f1, flops, elapsed_time

#開始測試

resnet_acc, resnet_precision, resnet_recall, resnet_f1, resnet_flops, resnet_elapsed_time = test(model)
print(f"Accuracy: {resnet_acc:.2f}%, Precision: {resnet_precision:.2f}%, Recall: {resnet_recall:.2f}%, F1 Score: {resnet_f1:.2f}%, FLOPS: {resnet_flops:d}, Elapsed Time: {resnet_elapsed_time:.2f} seconds")

Testing: 100%|██████████| 8/8 [00:00<00:00,  8.14it/s, accuracy=51.1]


Accuracy: 51.11%, Precision: 52.57%, Recall: 51.11%, F1 Score: 49.78%, FLOPS: 333585408, Elapsed Time: 0.98 seconds
