In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

In [2]:
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models

In [3]:
import os

from torch.utils.data import Dataset, DataLoader

from PIL import Image
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score

# 1. Chuẩn bị dữ liệu

- Bước 1: Đọc ảnh từ ổ cứng
- Bước 2: Biến đổi ảnh
- Bước 3: Tạo label tương ứng với ảnh

In [4]:
class MyBrainTumorDataset(Dataset):
    def __init__(self, data_folder, csv_path):
        self.data_folder = data_folder
        self.image_names = [name for name in os.listdir(data_folder) if name.endswith('.jpg')]

        # Transformation => Bước 2
        self.transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        
        # Label => Bước 3
        self.label_df = pd.read_csv(csv_path, usecols=['image_name', 'label'])
        self.image_names = [name for name in self.image_names if name in self.label_df.image_name.to_list()]
        
    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = self.image_names[idx]

        # Buoc 1: Doc anh tu o cung
        image = Image.open(os.path.join(self.data_folder, image_name))

        # Buoc 2: Bien doi anh
        transformed_image = self.transform(image)

        # Buoc 3: Tao label cho anh
        label_str = self.label_df[self.label_df.image_name == image_name]['label'].values[0]
        if label_str == 'tumor':
            label = 1
        else:
            label = 0

        return transformed_image, label

In [5]:
train_dataset = MyBrainTumorDataset(
    data_folder='brain_tumor_mri_dataset/train',
    csv_path='brain_tumor_mri_dataset/label.csv'
)

In [6]:
len(train_dataset)

5744

In [7]:
train_dataset[0]

(tensor([[[-0.7333, -0.2784, -0.6706,  ..., -0.7961,  0.0588, -0.5843],
          [-0.7255, -0.4118, -0.6627,  ..., -0.8275, -0.0902, -0.7020],
          [-0.9765, -0.9529, -0.9843,  ..., -0.8353, -0.1608, -0.5373],
          ...,
          [-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
          [-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
          [-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922]],
 
         [[-0.7333, -0.2784, -0.6706,  ..., -0.7961,  0.0588, -0.5843],
          [-0.7255, -0.4118, -0.6627,  ..., -0.8275, -0.0902, -0.7020],
          [-0.9765, -0.9529, -0.9843,  ..., -0.8353, -0.1608, -0.5373],
          ...,
          [-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
          [-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922],
          [-0.9922, -0.9922, -0.9922,  ..., -0.9922, -0.9922, -0.9922]],
 
         [[-0.7333, -0.2784, -0.6706,  ..., -0.7961,  0.0588, -0.5843],
          [-0.7255, -0.4118,

In [8]:
train_dataloader = DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=0
)
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x12ba7ecf0>

# 2. Xây dựng mô hình

In [9]:
class MyCustomCNN(nn.Module):
    def __init__(self):
        super(MyCustomCNN, self).__init__()

        self.conv_1 = nn.Conv2d(
            in_channels=3,
            out_channels=16,
            kernel_size=3,
            stride=1,
            padding=1
        )
        self.relu_1 = nn.ReLU()
        self.pool_1 = nn.MaxPool2d(
            kernel_size=2,
            stride=2
        )

        self.conv_2 = nn.Conv2d(
            in_channels=16,
            out_channels=32,
            kernel_size=3,
            stride=1,
            padding=1
        )
        self.relu_2 = nn.ReLU()
        self.pool_2 = nn.MaxPool2d(
            kernel_size=2,
            stride=2
        )

        self.linear_1 = nn.Linear(32 * 16 * 16, 128)
        self.relu_3 = nn.ReLU()

        self.linear_2 = nn.Linear(128, 2)
        self.softmax = nn.Softmax()

    def forward(self, x):
        x = self.conv_1(x)
        x = self.relu_1(x)
        x = self.pool_1(x)

        x = self.conv_2(x)
        x = self.relu_2(x)
        x = self.pool_2(x)
        
        x = x.view(-1, 32 * 16 * 16)
        x = self.linear_1(x)
        x = self.relu_3(x)
        
        x = self.linear_2(x)
        x = self.softmax(x)
        return x

In [10]:
model = MyCustomCNN()

In [11]:
# model.cuda()

In [12]:
model

MyCustomCNN(
  (conv_1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_1): ReLU()
  (pool_1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv_2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu_2): ReLU()
  (pool_2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (linear_1): Linear(in_features=8192, out_features=128, bias=True)
  (relu_3): ReLU()
  (linear_2): Linear(in_features=128, out_features=2, bias=True)
  (softmax): Softmax(dim=None)
)

In [13]:
class MyResNetCNN(nn.Module):
    def __init__(self):
        super(MyResNetCNN, self).__init__()
        
        self.backbone = models.resnet18(pretrained=True)
        num_features = self.backbone.fc.in_features
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
        
        self.linear = nn.Linear(num_features, 2)
        self.softmax = nn.Softmax()

    def forward(self, x):
        x = self.backbone(x)
        x = x.squeeze()
        x = self.linear(x)
        x = self.softmax(x)
        return x

In [14]:
model = MyResNetCNN()
model



MyResNetCNN(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=

In [15]:
# model.cuda()

# 3. Huấn luyện mô hình

## 3.1. Khởi tạo hàm Loss và thuật toán tối ưu Optimizer

In [16]:
loss_func = nn.CrossEntropyLoss()
loss_func

CrossEntropyLoss()

In [17]:
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
optimizer

SGD (
Parameter Group 0
    dampening: 0
    differentiable: False
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    momentum: 0.9
    nesterov: False
    weight_decay: 0
)

## 3.2. Huấn luyện và đánh giá mô hình

In [18]:
writer = SummaryWriter(log_dir='train_logs')
writer

<torch.utils.tensorboard.writer.SummaryWriter at 0x12bcbc980>

In [19]:
num_epoch = 10

In [20]:
ckpt_dir = 'ckpt'
os.makedirs(ckpt_dir, exist_ok=True)

In [21]:
for epoch in range(num_epoch):
    # Train model
    model.train()
    for iteration_, (image, label) in enumerate(tqdm(train_dataloader, total=len(train_dataloader))):
        
        # Bước 1: Optimizer zero grad
        optimizer.zero_grad()

        # Bước 2: Foward data to model
        # image.cuda()
        pred = model(image)

        # Bước 3: Tính giá trị loss
        loss_value = loss_func(pred, label)

        # Bước 4: Cập nhật trọng số của mô hình
        loss_value.backward()
        optimizer.step()
        
        global_iteration = epoch * len(train_dataloader) + iteration_
        writer.add_scalar('train_loss_iter', loss_value, global_iteration)

    # Bước 5: (Tuỳ chọn) In các thông số ra ngoài màn hình
    print(f'Epoch={epoch}', f'Training loss={loss_value.item()}')
    writer.add_scalar('train_loss_epoch', loss_value, epoch)

    # Evaluate model
    model.eval()
    with torch.no_grad():
        loss_sum = 0
        pred_list, label_list = [], []
        for image, label in tqdm(train_dataloader, total=len(train_dataloader)):
            # image.cuda()
            pred = model(image)
            loss = loss_func(pred, label)
            loss_sum += loss.item()

            pred_list.append(pred)
            label_list.append(label)

        print(f'Test loss {loss_sum / len(train_dataloader)}')
        writer.add_scalar('test_loss_epoch', loss_value, epoch)
        
        # Calculate metrics
        final_pred = torch.concat(pred_list)
        final_pred = torch.argmax(final_pred, axis=1)
        final_label = torch.concat(label_list)
        
        epoch_accuracy_score = accuracy_score(final_pred, final_label)
        writer.add_scalar('test_accuracy_score_epoch', epoch_accuracy_score, epoch)

        epoch_precision_score = precision_score(final_pred, final_label)
        writer.add_scalar('test_precision_score_epoch', epoch_precision_score, epoch)

        epoch_recall_score = recall_score(final_pred, final_label)
        writer.add_scalar('test_recall_score_epoch', epoch_recall_score, epoch)

        epoch_f1_score = f1_score(final_pred, final_label)
        writer.add_scalar('test_f1_score_epoch', epoch_f1_score, epoch)

        print(classification_report(final_pred, final_label))
        
        torch.save(model.state_dict(), os.path.join(ckpt_dir, f'ckpt_{epoch}.pth'))

  return self._call_impl(*args, **kwargs)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [01:33<00:00, 30.61it/s]


Epoch=0 Training loss=0.3132641017436981


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [00:36<00:00, 78.97it/s]


Test loss 0.37830169294496435
              precision    recall  f1-score   support

           0       0.73      0.94      0.82       939
           1       0.99      0.93      0.96      4805

    accuracy                           0.93      5744
   macro avg       0.86      0.93      0.89      5744
weighted avg       0.94      0.93      0.94      5744



  return self._call_impl(*args, **kwargs)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [01:33<00:00, 30.64it/s]


Epoch=1 Training loss=0.3132619857788086


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [00:35<00:00, 81.87it/s]


Test loss 0.36916841417135965
              precision    recall  f1-score   support

           0       0.88      0.85      0.87      1255
           1       0.96      0.97      0.96      4489

    accuracy                           0.94      5744
   macro avg       0.92      0.91      0.91      5744
weighted avg       0.94      0.94      0.94      5744



  return self._call_impl(*args, **kwargs)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [01:31<00:00, 31.36it/s]


Epoch=2 Training loss=0.3132628798484802


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [00:34<00:00, 82.49it/s]


Test loss 0.38286861986918036
              precision    recall  f1-score   support

           0       0.75      0.90      0.82      1015
           1       0.98      0.94      0.96      4729

    accuracy                           0.93      5744
   macro avg       0.87      0.92      0.89      5744
weighted avg       0.94      0.93      0.93      5744



  return self._call_impl(*args, **kwargs)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [01:35<00:00, 29.99it/s]


Epoch=3 Training loss=0.3132633566856384


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [00:34<00:00, 82.77it/s]


Test loss 0.36824356939923797
              precision    recall  f1-score   support

           0       0.77      0.96      0.85       966
           1       0.99      0.94      0.97      4778

    accuracy                           0.94      5744
   macro avg       0.88      0.95      0.91      5744
weighted avg       0.95      0.94      0.95      5744



  return self._call_impl(*args, **kwargs)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [01:32<00:00, 30.93it/s]


Epoch=4 Training loss=0.31327128410339355


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [00:35<00:00, 80.19it/s]


Test loss 0.3613047877477072
              precision    recall  f1-score   support

           0       0.88      0.88      0.88      1210
           1       0.97      0.97      0.97      4534

    accuracy                           0.95      5744
   macro avg       0.93      0.93      0.93      5744
weighted avg       0.95      0.95      0.95      5744



  return self._call_impl(*args, **kwargs)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [01:30<00:00, 31.60it/s]


Epoch=5 Training loss=0.3132787346839905


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [00:35<00:00, 81.19it/s]


Test loss 0.359987827998516
              precision    recall  f1-score   support

           0       0.78      0.99      0.87       951
           1       1.00      0.94      0.97      4793

    accuracy                           0.95      5744
   macro avg       0.89      0.97      0.92      5744
weighted avg       0.96      0.95      0.95      5744



  return self._call_impl(*args, **kwargs)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [01:32<00:00, 30.96it/s]


Epoch=6 Training loss=0.31332260370254517


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [00:35<00:00, 81.02it/s]


Test loss 0.35943388148187594
              precision    recall  f1-score   support

           0       0.82      0.95      0.88      1052
           1       0.99      0.95      0.97      4692

    accuracy                           0.95      5744
   macro avg       0.91      0.95      0.93      5744
weighted avg       0.96      0.95      0.95      5744



  return self._call_impl(*args, **kwargs)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [01:32<00:00, 31.07it/s]


Epoch=7 Training loss=0.34625330567359924


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [00:35<00:00, 81.36it/s]


Test loss 0.37209276182571827
              precision    recall  f1-score   support

           0       0.78      0.92      0.85      1026
           1       0.98      0.94      0.96      4718

    accuracy                           0.94      5744
   macro avg       0.88      0.93      0.90      5744
weighted avg       0.95      0.94      0.94      5744



  return self._call_impl(*args, **kwargs)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [01:32<00:00, 30.96it/s]


Epoch=8 Training loss=0.7207817435264587


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [00:34<00:00, 83.18it/s]


Test loss 0.3481675073981783
              precision    recall  f1-score   support

           0       0.90      0.92      0.91      1185
           1       0.98      0.97      0.98      4559

    accuracy                           0.96      5744
   macro avg       0.94      0.95      0.95      5744
weighted avg       0.96      0.96      0.96      5744



  return self._call_impl(*args, **kwargs)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [01:27<00:00, 32.91it/s]


Epoch=9 Training loss=0.31327927112579346


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [00:33<00:00, 85.06it/s]


Test loss 0.3415459215973081
              precision    recall  f1-score   support

           0       0.93      0.93      0.93      1207
           1       0.98      0.98      0.98      4537

    accuracy                           0.97      5744
   macro avg       0.96      0.96      0.96      5744
weighted avg       0.97      0.97      0.97      5744



# 4. Predict new data

In [22]:
state_dict = torch.load('ckpt/ckpt_0.pth', map_location='cpu')
state_dict

OrderedDict([('backbone.0.weight',
              tensor([[[[-1.3875e-02, -5.8965e-03, -2.9944e-03,  ...,  4.9242e-02,
                          1.1644e-02, -1.5179e-02],
                        [ 1.4622e-02,  1.6225e-02, -1.0648e-01,  ..., -2.7473e-01,
                         -1.3304e-01,  4.4973e-03],
                        [-3.8930e-03,  6.3580e-02,  2.9982e-01,  ...,  5.1719e-01,
                          2.5064e-01,  5.9347e-02],
                        ...,
                        [-2.7906e-02,  1.6048e-02,  6.9598e-02,  ..., -3.3980e-01,
                         -4.3255e-01, -2.6715e-01],
                        [ 3.3713e-02,  3.5409e-02,  5.8202e-02,  ...,  4.0306e-01,
                          3.8190e-01,  1.5948e-01],
                        [-1.3356e-02, -1.1038e-02, -3.0789e-02,  ..., -1.6188e-01,
                         -9.2309e-02, -1.4061e-02]],
              
                       [[-1.4811e-02, -2.6337e-02, -3.5815e-02,  ...,  2.5188e-02,
                         -4

In [23]:
# new_model_1 = MyCustomCNN()

In [24]:
new_model_2 = MyResNetCNN()



In [25]:
new_model_2.load_state_dict(state_dict)

<All keys matched successfully>

In [26]:
# new_model_1.load_state_dict(state_dict)

In [27]:
new_model_2.eval()
with torch.no_grad():
    loss_sum = 0
    pred_list, label_list = [], []
    for image, label in tqdm(train_dataloader, total=len(train_dataloader)):
        pred = new_model_2(image)
        loss = loss_func(pred, label)
        loss_sum += loss.item()

        pred_list.append(pred)
        label_list.append(label)

    print(f'Test loss {loss_sum / len(train_dataloader)}')

    # Calculate metrics
    final_pred = torch.concat(pred_list)
    final_pred = torch.argmax(final_pred, axis=1)
    final_label = torch.concat(label_list)

    print(classification_report(final_pred, final_label))

  return self._call_impl(*args, **kwargs)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2872/2872 [00:33<00:00, 85.23it/s]

Test loss 0.3783016928411958
              precision    recall  f1-score   support

           0       0.73      0.94      0.82       939
           1       0.99      0.93      0.96      4805

    accuracy                           0.93      5744
   macro avg       0.86      0.93      0.89      5744
weighted avg       0.94      0.93      0.94      5744




