In [1]:
# pip install torch==2.3.0 torchvision==0.18.0 tensorboard==2.16.2 scikit-learn==1.5.0 pandas==2.2.2 tqdm==4.66.4

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

In [3]:
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

In [4]:
from torch.utils.data import Dataset, DataLoader

In [5]:
import os

from PIL import Image
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score

# 1. Prepare data

- Step 1: Read image from hard drive
- Step 2: Transform image
- Step 3: Prepare image label

In [6]:
class MyBrainTumorDataset(Dataset):
    def __init__(self, data_folder, csv_path):
        self.data_folder = data_folder
        self.image_names = [name for name in os.listdir(data_folder) if name.endswith('.jpg')]

        # Transformation => Step 2
        self.transform = transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Label => Step 3
        self.label_df = pd.read_csv(csv_path, usecols=['image_name', 'label'])
        self.image_names = [name for name in self.image_names if name in self.label_df.image_name.to_list()]
        
    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]

        # Step 1: Read image from hard drive
        image = Image.open(os.path.join(self.data_folder, image_name))
        
        # Step 2: Transform image
        transformed_image = self.transform(image)

        # Step 3: Prepare image label
        label_str = self.label_df[self.label_df.image_name == image_name]['label'].values[0]
        if label_str == 'tumor':
            label = torch.tensor([0, 1], dtype=torch.float32)
        else:
            label = torch.tensor([1, 0], dtype=torch.float32)

        return transformed_image, label

In [7]:
train_dataset = MyBrainTumorDataset(
    data_folder='data/brain_tumor_dataset/train',
    csv_path='data/brain_tumor_dataset/brain_multi.csv'
)

In [8]:
len(train_dataset)

5744

In [9]:
train_dataset[1]

(tensor([[[-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          ...,
          [-1.0000, -1.0000, -1.0000,  ..., -0.9922, -0.7020, -0.3176],
          [-0.9922, -1.0000, -0.9922,  ..., -1.0000, -0.7804, -0.4039],
          [-0.9922, -1.0000, -1.0000,  ..., -1.0000, -0.9294, -0.7569]],
 
         [[-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          ...,
          [-1.0000, -1.0000, -1.0000,  ..., -0.9922, -0.7020, -0.3176],
          [-0.9922, -1.0000, -0.9922,  ..., -1.0000, -0.7804, -0.4039],
          [-0.9922, -1.0000, -1.0000,  ..., -1.0000, -0.9294, -0.7569]],
 
         [[-1.0000, -1.0000, -1.0000,  ..., -1.0000, -1.0000, -1.0000],
          [-1.0000, -1.0000,

In [10]:
train_dataloader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=0
)
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x16963ae20>

# 2. Build model

In [11]:
class MyCustomCNN(nn.Module):
    def __init__(self):
        super(MyCustomCNN, self).__init__()

        # Image features extractor: Embed image into vector
        self.conv_1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.relu_1 = nn.ReLU()
        self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.relu_2 = nn.ReLU()
        self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.relu_3 = nn.ReLU()
        self.pool_3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv_4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.relu_4 = nn.ReLU()
        self.pool_4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classification
        self.linear_1 = nn.Linear(128 * 8 * 8, 128)
        self.relu_5 = nn.ReLU()
        self.linear_2 = nn.Linear(128, 2)
        self.softmax = nn.Softmax()

    def forward(self, x):
        x = self.conv_1(x)
        x = self.relu_1(x)
        x = self.pool_1(x)

        x = self.conv_2(x)
        x = self.relu_2(x)
        x = self.pool_2(x)

        x = self.conv_3(x)
        x = self.relu_3(x)
        x = self.pool_3(x)

        x = self.conv_4(x)
        x = self.relu_4(x)
        x = self.pool_4(x)

        x = x.view(-1, 128 * 8 * 8)
        x = self.linear_1(x)
        x = self.relu_5(x)

        x = self.linear_2(x)
        x = self.softmax(x)
        return x

In [12]:
# log_name = 'my_custom_cnn'
# model = MyCustomCNN()
# model

In [13]:
# model.cuda()

In [14]:
class MyResNetCNN(nn.Module):
    def __init__(self):
        super(MyResNetCNN, self).__init__()
        
        self.backbone = models.resnet18(pretrained=True)
        num_features = self.backbone.fc.in_features
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])

        self.linear = nn.Linear(num_features, 2)
        self.softmax = nn.Softmax()

    def forward(self, x):
        x = self.backbone(x)
        x = x.squeeze()
        x = self.linear(x)
        x = self.softmax(x)
        return x

In [15]:
log_name = 'my_resnet_cnn'
model = MyResNetCNN()
model



MyResNetCNN(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=

In [16]:
# model.cuda()

# 3. Train model

## 3.1. Init loss function and optimizer

In [17]:
loss_func = nn.CrossEntropyLoss()
loss_func

CrossEntropyLoss()

In [18]:
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
optimizer

SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    momentum: 0.9
    nesterov: False
    weight_decay: 0
)

## 3.2. Train and evaluate model

In [19]:
writer = SummaryWriter(log_dir=f'train_logs/{log_name}')
writer

<torch.utils.tensorboard.writer.SummaryWriter at 0x1696e0a00>

In [20]:
num_epoch = 10

In [21]:
model

MyResNetCNN(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=

In [22]:
ckpt_folfer = f'ckpt/{log_name}'
os.makedirs(ckpt_folfer, exist_ok=True)

In [23]:
for epoch in range(num_epoch):
    # Train model
    model.train()
    for iteration_, (image, label) in enumerate(tqdm(train_dataloader, total=len(train_dataloader))):
        
        # Bước 1: Optimizer zero grad
        optimizer.zero_grad()

        # Bước 2: Foward data to model
        # image.cuda()
        pred = model(image)

        # Bước 3: Tính giá trị loss
        loss_value = loss_func(pred, label)

        # Bước 4: Cập nhật trọng số của mô hình
        loss_value.backward()
        optimizer.step()
        
        global_iteration = epoch * len(train_dataloader) + iteration_
        writer.add_scalar('train_loss_iter', loss_value, global_iteration)

    # Bước 5: (Tuỳ chọn) In các thông số ra ngoài màn hình
    print(f'Epoch={epoch}', f'Training loss={loss_value.item()}')
    writer.add_scalar('train_loss_epoch', loss_value, epoch)

    # Evaluate model
    model.eval()
    with torch.no_grad():
        loss_sum = 0
        pred_list, label_list = [], []
        for image, label in tqdm(train_dataloader, total=len(train_dataloader)):
            # image.cuda()
            pred = model(image)
            loss = loss_func(pred, label)
            loss_sum += loss.item()

            pred_list.append(pred)
            label_list.append(label)

        print(f'Test loss {loss_sum / len(train_dataloader)}')
        writer.add_scalar('test_loss_epoch', loss_value, epoch)
        
        # Calculate metrics
        final_pred = torch.concat(pred_list)
        final_pred = torch.argmax(final_pred, axis=1)
        final_label = torch.concat(label_list)
        final_label = torch.argmax(final_label, axis=1)
        

        epoch_accuracy_score = accuracy_score(final_pred, final_label)
        writer.add_scalar('test_accuracy_score_epoch', epoch_accuracy_score, epoch)

        epoch_precision_score = precision_score(final_pred, final_label)
        writer.add_scalar('test_precision_score_epoch', epoch_precision_score, epoch)

        epoch_recall_score = recall_score(final_pred, final_label)
        writer.add_scalar('test_recall_score_epoch', epoch_recall_score, epoch)

        epoch_f1_score = f1_score(final_pred, final_label)
        writer.add_scalar('test_f1_score_epoch', epoch_f1_score, epoch)

        print(classification_report(final_pred, final_label))
        
        torch.save(model.state_dict(), os.path.join(ckpt_folfer, f'ckpt_{epoch}.pth'))

  return self._call_impl(*args, **kwargs)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [01:08<00:00,  1.31it/s]


Epoch=0 Training loss=0.3939453661441803


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [00:47<00:00,  1.89it/s]


Test loss 0.4108459111717012
              precision    recall  f1-score   support

           0       0.66      0.89      0.76       892
           1       0.98      0.91      0.95      4852

    accuracy                           0.91      5744
   macro avg       0.82      0.90      0.85      5744
weighted avg       0.93      0.91      0.92      5744



  return self._call_impl(*args, **kwargs)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [01:06<00:00,  1.35it/s]


Epoch=1 Training loss=0.3578239977359772


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [00:49<00:00,  1.82it/s]


Test loss 0.35677065120802987
              precision    recall  f1-score   support

           0       0.83      0.98      0.90      1021
           1       1.00      0.96      0.98      4723

    accuracy                           0.96      5744
   macro avg       0.91      0.97      0.94      5744
weighted avg       0.97      0.96      0.96      5744



  return self._call_impl(*args, **kwargs)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [01:09<00:00,  1.29it/s]


Epoch=2 Training loss=0.33497485518455505


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [00:49<00:00,  1.83it/s]


Test loss 0.33919177883201174
              precision    recall  f1-score   support

           0       0.92      0.98      0.95      1141
           1       0.99      0.98      0.99      4603

    accuracy                           0.98      5744
   macro avg       0.96      0.98      0.97      5744
weighted avg       0.98      0.98      0.98      5744



  return self._call_impl(*args, **kwargs)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [01:06<00:00,  1.36it/s]


Epoch=3 Training loss=0.3365288972854614


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [00:47<00:00,  1.88it/s]


Test loss 0.33052301671769885
              precision    recall  f1-score   support

           0       0.95      0.98      0.97      1166
           1       1.00      0.99      0.99      4578

    accuracy                           0.99      5744
   macro avg       0.97      0.99      0.98      5744
weighted avg       0.99      0.99      0.99      5744



  return self._call_impl(*args, **kwargs)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [01:06<00:00,  1.36it/s]


Epoch=4 Training loss=0.3173868656158447


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [00:47<00:00,  1.90it/s]


Test loss 0.32292393777105544
              precision    recall  f1-score   support

           0       0.98      0.99      0.98      1189
           1       1.00      0.99      1.00      4555

    accuracy                           0.99      5744
   macro avg       0.99      0.99      0.99      5744
weighted avg       0.99      0.99      0.99      5744



  return self._call_impl(*args, **kwargs)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [01:06<00:00,  1.35it/s]


Epoch=5 Training loss=0.3599858582019806


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [00:48<00:00,  1.86it/s]


Test loss 0.32023748788568707
              precision    recall  f1-score   support

           0       0.98      0.99      0.99      1190
           1       1.00      0.99      1.00      4554

    accuracy                           0.99      5744
   macro avg       0.99      0.99      0.99      5744
weighted avg       0.99      0.99      0.99      5744



  return self._call_impl(*args, **kwargs)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [01:07<00:00,  1.33it/s]


Epoch=6 Training loss=0.3581777811050415


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [00:47<00:00,  1.90it/s]


Test loss 0.3191981659995185
              precision    recall  f1-score   support

           0       0.98      0.99      0.99      1191
           1       1.00      0.99      1.00      4553

    accuracy                           0.99      5744
   macro avg       0.99      0.99      0.99      5744
weighted avg       1.00      0.99      0.99      5744



  return self._call_impl(*args, **kwargs)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [01:08<00:00,  1.32it/s]


Epoch=7 Training loss=0.31391486525535583


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [00:49<00:00,  1.81it/s]


Test loss 0.3176514271232817
              precision    recall  f1-score   support

           0       0.99      0.99      0.99      1198
           1       1.00      1.00      1.00      4546

    accuracy                           1.00      5744
   macro avg       0.99      1.00      0.99      5744
weighted avg       1.00      1.00      1.00      5744



  return self._call_impl(*args, **kwargs)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [01:08<00:00,  1.31it/s]


Epoch=8 Training loss=0.36076733469963074


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [00:49<00:00,  1.81it/s]


Test loss 0.31745612422625225
              precision    recall  f1-score   support

           0       0.99      0.99      0.99      1207
           1       1.00      1.00      1.00      4537

    accuracy                           1.00      5744
   macro avg       0.99      1.00      1.00      5744
weighted avg       1.00      1.00      1.00      5744



  return self._call_impl(*args, **kwargs)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [01:06<00:00,  1.35it/s]


Epoch=9 Training loss=0.3149544596672058


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [00:47<00:00,  1.89it/s]

Test loss 0.3164250754647785
              precision    recall  f1-score   support

           0       0.99      1.00      0.99      1203
           1       1.00      1.00      1.00      4541

    accuracy                           1.00      5744
   macro avg       1.00      1.00      1.00      5744
weighted avg       1.00      1.00      1.00      5744






# 4. Predict new data

In [24]:
state_dict = torch.load(f'{ckpt_folfer}/ckpt_3.pth', map_location='cpu')
state_dict

OrderedDict([('backbone.0.weight',
              tensor([[[[-1.0497e-02, -5.7113e-03, -1.7750e-03,  ...,  5.6866e-02,
                          1.7047e-02, -1.1706e-02],
                        [ 1.0042e-02,  9.4565e-03, -1.1070e-01,  ..., -2.7198e-01,
                         -1.2945e-01,  4.3973e-03],
                        [-8.1629e-03,  5.9719e-02,  2.9556e-01,  ...,  5.1936e-01,
                          2.5675e-01,  6.4104e-02],
                        ...,
                        [-2.8634e-02,  1.6416e-02,  7.2694e-02,  ..., -3.3211e-01,
                         -4.1960e-01, -2.5638e-01],
                        [ 2.9612e-02,  4.1070e-02,  6.3200e-02,  ...,  4.1529e-01,
                          3.9475e-01,  1.6746e-01],
                        [-1.4377e-02, -3.1112e-03, -2.2976e-02,  ..., -1.4909e-01,
                         -8.0317e-02, -4.3318e-03]],
              
                       [[-1.1479e-02, -2.6201e-02, -3.4613e-02,  ...,  3.2770e-02,
                          6

In [25]:
new_model_1 = MyCustomCNN()

In [26]:
new_model_2 = MyResNetCNN()



In [27]:
new_model_2.load_state_dict(state_dict)

<All keys matched successfully>

In [28]:
new_model_1.load_state_dict(state_dict)

RuntimeError: Error(s) in loading state_dict for MyCustomCNN:
	Missing key(s) in state_dict: "conv_1.weight", "conv_1.bias", "conv_2.weight", "conv_2.bias", "conv_3.weight", "conv_3.bias", "conv_4.weight", "conv_4.bias", "linear_1.weight", "linear_1.bias", "linear_2.weight", "linear_2.bias". 
	Unexpected key(s) in state_dict: "backbone.0.weight", "backbone.1.weight", "backbone.1.bias", "backbone.1.running_mean", "backbone.1.running_var", "backbone.1.num_batches_tracked", "backbone.4.0.conv1.weight", "backbone.4.0.bn1.weight", "backbone.4.0.bn1.bias", "backbone.4.0.bn1.running_mean", "backbone.4.0.bn1.running_var", "backbone.4.0.bn1.num_batches_tracked", "backbone.4.0.conv2.weight", "backbone.4.0.bn2.weight", "backbone.4.0.bn2.bias", "backbone.4.0.bn2.running_mean", "backbone.4.0.bn2.running_var", "backbone.4.0.bn2.num_batches_tracked", "backbone.4.1.conv1.weight", "backbone.4.1.bn1.weight", "backbone.4.1.bn1.bias", "backbone.4.1.bn1.running_mean", "backbone.4.1.bn1.running_var", "backbone.4.1.bn1.num_batches_tracked", "backbone.4.1.conv2.weight", "backbone.4.1.bn2.weight", "backbone.4.1.bn2.bias", "backbone.4.1.bn2.running_mean", "backbone.4.1.bn2.running_var", "backbone.4.1.bn2.num_batches_tracked", "backbone.5.0.conv1.weight", "backbone.5.0.bn1.weight", "backbone.5.0.bn1.bias", "backbone.5.0.bn1.running_mean", "backbone.5.0.bn1.running_var", "backbone.5.0.bn1.num_batches_tracked", "backbone.5.0.conv2.weight", "backbone.5.0.bn2.weight", "backbone.5.0.bn2.bias", "backbone.5.0.bn2.running_mean", "backbone.5.0.bn2.running_var", "backbone.5.0.bn2.num_batches_tracked", "backbone.5.0.downsample.0.weight", "backbone.5.0.downsample.1.weight", "backbone.5.0.downsample.1.bias", "backbone.5.0.downsample.1.running_mean", "backbone.5.0.downsample.1.running_var", "backbone.5.0.downsample.1.num_batches_tracked", "backbone.5.1.conv1.weight", "backbone.5.1.bn1.weight", "backbone.5.1.bn1.bias", "backbone.5.1.bn1.running_mean", "backbone.5.1.bn1.running_var", "backbone.5.1.bn1.num_batches_tracked", "backbone.5.1.conv2.weight", "backbone.5.1.bn2.weight", "backbone.5.1.bn2.bias", "backbone.5.1.bn2.running_mean", "backbone.5.1.bn2.running_var", "backbone.5.1.bn2.num_batches_tracked", "backbone.6.0.conv1.weight", "backbone.6.0.bn1.weight", "backbone.6.0.bn1.bias", "backbone.6.0.bn1.running_mean", "backbone.6.0.bn1.running_var", "backbone.6.0.bn1.num_batches_tracked", "backbone.6.0.conv2.weight", "backbone.6.0.bn2.weight", "backbone.6.0.bn2.bias", "backbone.6.0.bn2.running_mean", "backbone.6.0.bn2.running_var", "backbone.6.0.bn2.num_batches_tracked", "backbone.6.0.downsample.0.weight", "backbone.6.0.downsample.1.weight", "backbone.6.0.downsample.1.bias", "backbone.6.0.downsample.1.running_mean", "backbone.6.0.downsample.1.running_var", "backbone.6.0.downsample.1.num_batches_tracked", "backbone.6.1.conv1.weight", "backbone.6.1.bn1.weight", "backbone.6.1.bn1.bias", "backbone.6.1.bn1.running_mean", "backbone.6.1.bn1.running_var", "backbone.6.1.bn1.num_batches_tracked", "backbone.6.1.conv2.weight", "backbone.6.1.bn2.weight", "backbone.6.1.bn2.bias", "backbone.6.1.bn2.running_mean", "backbone.6.1.bn2.running_var", "backbone.6.1.bn2.num_batches_tracked", "backbone.7.0.conv1.weight", "backbone.7.0.bn1.weight", "backbone.7.0.bn1.bias", "backbone.7.0.bn1.running_mean", "backbone.7.0.bn1.running_var", "backbone.7.0.bn1.num_batches_tracked", "backbone.7.0.conv2.weight", "backbone.7.0.bn2.weight", "backbone.7.0.bn2.bias", "backbone.7.0.bn2.running_mean", "backbone.7.0.bn2.running_var", "backbone.7.0.bn2.num_batches_tracked", "backbone.7.0.downsample.0.weight", "backbone.7.0.downsample.1.weight", "backbone.7.0.downsample.1.bias", "backbone.7.0.downsample.1.running_mean", "backbone.7.0.downsample.1.running_var", "backbone.7.0.downsample.1.num_batches_tracked", "backbone.7.1.conv1.weight", "backbone.7.1.bn1.weight", "backbone.7.1.bn1.bias", "backbone.7.1.bn1.running_mean", "backbone.7.1.bn1.running_var", "backbone.7.1.bn1.num_batches_tracked", "backbone.7.1.conv2.weight", "backbone.7.1.bn2.weight", "backbone.7.1.bn2.bias", "backbone.7.1.bn2.running_mean", "backbone.7.1.bn2.running_var", "backbone.7.1.bn2.num_batches_tracked", "linear.weight", "linear.bias". 

In [None]:
new_model_2.eval()
with torch.no_grad():
    loss_sum = 0
    pred_list, label_list = [], []
    for image, label in tqdm(train_dataloader, total=len(train_dataloader)):
        pred = new_model_2(image)
        loss = loss_func(pred, label)
        loss_sum += loss.item()

        pred_list.append(pred)
        label_list.append(label)

    print(f'Test loss {loss_sum / len(train_dataloader)}')

    # Calculate metrics
    final_pred = torch.concat(pred_list)
    final_pred = torch.argmax(final_pred, axis=1)
    final_label = torch.concat(label_list)
    final_label = torch.argmax(final_label, axis=1)

    print(classification_report(final_pred, final_label))

In [None]:
# Test loss 0.47746641347625796
#               precision    recall  f1-score   support

#            0       0.13      0.91      0.23       173
#            1       1.00      0.81      0.89      5571

#     accuracy                           0.81      5744
#    macro avg       0.56      0.86      0.56      5744
# weighted avg       0.97      0.81      0.87      5744