In [39]:
# diagnosis of lung damage based on medical imagery
# dataset, model and training based on https://www.kaggle.com/code/fareedalianwar/chest-ctscan-pytorch-classification

In [41]:
import torch
from torch import nn
import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.models import ResNet18_Weights
from torch.optim import lr_scheduler
from torchinfo import summary
from tqdm.auto import tqdm

# Start from a ResNet-18 carrying the ImageNet-1K (V1) weights; later cells modify it.
resnet18_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

# Pick the compute device: the GPU when PyTorch reports one, the CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cpu'

In [42]:
# Sanity check: display whether PyTorch can see a CUDA device in this environment.
torch.cuda.is_available()

False

In [43]:
# Preprocessing with light augmentation: resize, random horizontal flip, small random
# rotation, then conversion to a tensor.
# NOTE(review): `data_transform` is reassigned in the data-loading cell further down before
# any dataset is built, so this augmented pipeline is never applied as the notebook stands.
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(0.7),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
])

In [44]:
# bring in the already imported data from the data directory
import os

# Deterministic preprocessing applied to every split: resize to 128x128, convert to tensor.
# NOTE(review): this rebinds `data_transform`, replacing the augmented pipeline defined in
# the previous cell, so no augmentation reaches the training images.
data_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

train_data_path = '../data/ctimages/Data/train'
# NOTE(review): the "test" split reads the same folder as the validation split, so the
# "test" metrics printed during training are validation metrics — confirm whether a
# separate held-out folder was intended.
test_data_path = '../data/ctimages/Data/valid'
validation_data_path = '../data/ctimages/Data/valid'

# ImageFolder takes the label of each image from the sub-directory it lives in.
train_data = datasets.ImageFolder(root=train_data_path, transform=data_transform)
test_data = datasets.ImageFolder(root=test_data_path, transform=data_transform)
validation_data = datasets.ImageFolder(root=validation_data_path, transform=data_transform)

# Batch loaders consumed by the training cell. Without them a fresh-kernel run stops with a
# NameError on `train_dataloader` / `test_dataloader`.
# NOTE(review): the batch size is an assumption — 32 is consistent with the accuracies
# logged by the training run (72 evaluation images seen as 3 batches); confirm.
BATCH_SIZE = 32
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)
validation_dataloader = DataLoader(validation_data, batch_size=BATCH_SIZE, shuffle=False)

# Displayed as the cell output: number of training and test images.
len(train_data), len(test_data)


(613, 72)

In [45]:
# Grab the label metadata exposed by the training dataset: the list of class names and
# the dictionary that maps each class name to its integer index.
class_names, class_dict = train_data.classes, train_data.class_to_idx

# Print both so the label encoding is visible in the notebook output.
print("Class Names:", class_names)
print("Class Dictionary:", class_dict)


Class Names: ['adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib', 'large.cell.carcinoma_left.hilum_T2_N2_M0_IIIa', 'normal', 'squamous.cell.carcinoma_left.hilum_T1_N2_M0_IIIa']
Class Dictionary: {'adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib': 0, 'large.cell.carcinoma_left.hilum_T2_N2_M0_IIIa': 1, 'normal': 2, 'squamous.cell.carcinoma_left.hilum_T1_N2_M0_IIIa': 3}


In [46]:
# Verbose folder name -> shortened label (the trailing `_T*_N*_M*_*` part is dropped).
key_updates = {
    'adenocarcinoma_left.lower.lobe_T2_N0_M0_Ib': 'adenocarcinoma_left.lower.lobe',
    'large.cell.carcinoma_left.hilum_T2_N2_M0_IIIa': 'large.cell.carcinoma_left.hilum',
    'squamous.cell.carcinoma_left.hilum_T1_N2_M0_IIIa': 'squamous.cell.carcinoma_left.hilum'
}

# Re-key the class -> index dictionary in place; every class keeps its original index.
# NOTE(review): `class_dict` and `class_names` were assigned from the dataset without a
# copy, so these edits also change `train_data.class_to_idx` / `train_data.classes`.
for old_key, new_key in key_updates.items():
    if old_key not in class_dict:
        continue  # already renamed (e.g. the cell was run before) or not present
    class_dict[new_key] = class_dict.pop(old_key)

# The shortened labels, also kept under individual names.
new_key1, new_key2, new_key3 = (
    'adenocarcinoma_left.lower.lobe',
    'large.cell.carcinoma_left.hilum',
    'squamous.cell.carcinoma_left.hilum',
)

# Overwrite the matching positions of the class list; index 2 ('normal') is left as it is.
class_names[0], class_names[1], class_names[3] = new_key1, new_key2, new_key3

print(class_dict, len(class_names))

{'normal': 2, 'adenocarcinoma_left.lower.lobe': 0, 'large.cell.carcinoma_left.hilum': 1, 'squamous.cell.carcinoma_left.hilum': 3} 4


In [47]:
# Swap the network's final classifier for a small head that ends in one logit per class.
# NOTE(review): there is no activation between the Linear layers, so apart from dropout the
# head composes to a single linear map. Inserting activations would shift the Sequential
# indices and therefore the keys of any saved state_dict — decide deliberately.
resnet18_model.fc = nn.Sequential(
    nn.Linear(in_features=512, out_features=512),
    nn.Dropout(p=0.2),
    nn.Linear(in_features=512, out_features=256),
    nn.Linear(in_features=256, out_features=len(class_names)),  # output size follows the label list
)
resnet18_model.fc

Sequential(
  (0): Linear(in_features=512, out_features=512, bias=True)
  (1): Dropout(p=0.2, inplace=False)
  (2): Linear(in_features=512, out_features=256, bias=True)
  (3): Linear(in_features=256, out_features=4, bias=True)
)

In [48]:
# The ResNet-18 stem is already Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False).
# Rebuilding it unconditionally would replace the pretrained filters with randomly
# initialised ones, and re-running this cell after the optimizer has been created would
# leave the optimizer tracking the discarded parameters. Only rebuild the stem when its
# input channel count does not match the images.
# NOTE(review): assumes the data pipeline yields 3-channel images, in line with the
# 3-channel input_size passed to summary() — confirm.
if resnet18_model.conv1.in_channels != 3:
    resnet18_model.conv1 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

In [49]:
# Place the model on the selected device before the optimizer is built. The training
# helpers move every batch to `device`, so a model left on the CPU would fail with a
# device mismatch as soon as a GPU is available.
resnet18_model.to(device)

# Multi-class classification loss applied to the raw logits.
loss_fn = nn.CrossEntropyLoss()

# SGD with momentum is the optimizer handed to training; it is the only one created so
# there is no ambiguity about which optimizer the run uses.
optimizer = torch.optim.SGD(resnet18_model.parameters(), lr=0.01, momentum=0.9)

# NOTE(review): the summary is computed for a 64x64 input, whereas the transform used to
# build the datasets resizes images to 128x128 — the listed output shapes differ from
# those seen during training.
summary(resnet18_model, input_size=[1, 3, 64, 64])

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [1, 4]                    --
├─Conv2d: 1-1                            [1, 64, 32, 32]           9,408
├─BatchNorm2d: 1-2                       [1, 64, 32, 32]           128
├─ReLU: 1-3                              [1, 64, 32, 32]           --
├─MaxPool2d: 1-4                         [1, 64, 16, 16]           --
├─Sequential: 1-5                        [1, 64, 16, 16]           --
│    └─BasicBlock: 2-1                   [1, 64, 16, 16]           --
│    │    └─Conv2d: 3-1                  [1, 64, 16, 16]           36,864
│    │    └─BatchNorm2d: 3-2             [1, 64, 16, 16]           128
│    │    └─ReLU: 3-3                    [1, 64, 16, 16]           --
│    │    └─Conv2d: 3-4                  [1, 64, 16, 16]           36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 16, 16]           128
│    │    └─ReLU: 3-6                    [1, 64, 16, 16]           --
│

In [50]:
# train the model
def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer:torch.optim.Optimizer,
               device=device):
  """Run one pass over ``dataloader``, updating ``model`` after every batch.

  Args:
    model: Network to optimise; it is switched to training mode.
    dataloader: Yields ``(X, y)`` batches; both are moved to ``device``.
    loss_fn: Called as ``loss_fn(model(X), y)``; the result must support
      ``.item()`` and ``.backward()``.
    optimizer: Optimizer whose ``zero_grad()`` / ``step()`` are called once per batch.
    device: Device the batches are moved to (defaults to the notebook-level ``device``).

  Returns:
    ``(train_loss, train_acc)``: the loss and the accuracy, each averaged over the
    number of batches. Every batch carries the same weight, so a smaller final batch
    counts as much as a full one.

  Raises:
    ZeroDivisionError: If ``dataloader`` yields no batches.
  """
  model.train()

  train_loss, train_acc = 0, 0

  for X, y in dataloader:
    X, y = X.to(device), y.to(device)

    # Forward pass and loss for this batch
    y_pred = model(X)
    loss = loss_fn(y_pred, y)
    train_loss += loss.item()

    # Clear stale gradients, backpropagate, then apply the update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Predicted class = index of the largest logit. Softmax is monotonic, so applying it
    # first cannot change the argmax; it is skipped, matching what test_step does.
    y_pred_class = y_pred.argmax(dim=1)
    train_acc += (y_pred_class == y).sum().item() / len(y_pred)

  # Adjust metrics to get average loss and accuracy per batch
  train_loss = train_loss / len(dataloader)
  train_acc = train_acc / len(dataloader)
  return train_loss, train_acc

In [51]:
def test_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              device=device):
  """Evaluate ``model`` on every batch of ``dataloader`` without updating it.

  The model is switched to evaluation mode and all forward passes run inside
  ``torch.inference_mode()``.

  Args:
    model: Network to evaluate.
    dataloader: Yields ``(X, y)`` batches; both are moved to ``device``.
    loss_fn: Called as ``loss_fn(logits, y)``; the result must support ``.item()``.
    device: Device the batches are moved to (defaults to the notebook-level ``device``).

  Returns:
    ``(test_loss, test_acc)``: loss and accuracy, each averaged over the number of
    batches.

  Raises:
    ZeroDivisionError: If ``dataloader`` yields no batches.
  """
  model.eval()

  running_loss = 0
  running_acc = 0

  with torch.inference_mode():
    for inputs, targets in dataloader:
      inputs = inputs.to(device)
      targets = targets.to(device)

      logits = model(inputs)
      running_loss += loss_fn(logits, targets).item()

      # Predicted class is the index of the largest logit along dim 1
      predicted = logits.argmax(dim=1)
      num_correct = (predicted == targets).sum().item()
      running_acc += num_correct / len(predicted)

  # Per-batch averages: every batch contributes equally
  num_batches = len(dataloader)
  return running_loss / num_batches, running_acc / num_batches

In [52]:
def train(model: torch.nn.Module,
          train_dataloader,
          test_dataloader,
          optimizer,
          loss_fn: torch.nn.Module = nn.CrossEntropyLoss(),
          epochs: int = 5, 
          device=device):
  """Train ``model`` for ``epochs`` epochs, evaluating after each one.

  Every epoch calls ``train_step`` on ``train_dataloader`` followed by ``test_step`` on
  ``test_dataloader``, prints one line with the four metrics and records them.

  Args:
    model: Network to train.
    train_dataloader: Batches used for the optimisation pass.
    test_dataloader: Batches used for the evaluation pass.
    optimizer: Optimizer forwarded to ``train_step``.
    loss_fn: Loss forwarded to both steps.
    epochs: Number of train/evaluate rounds.
    device: Device forwarded to both steps.

  Returns:
    Dict with the keys ``"train_loss"``, ``"train_acc"``, ``"test_loss"`` and
    ``"test_acc"``, each holding a list with one entry per epoch.
  """
  metric_names = ("train_loss", "train_acc", "test_loss", "test_acc")
  results = {name: [] for name in metric_names}

  for epoch in tqdm(range(epochs)):
    # Optimisation pass over the training batches
    train_loss, train_acc = train_step(
        model=model,
        dataloader=train_dataloader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        device=device,
    )
    # Evaluation pass, no weight updates
    test_loss, test_acc = test_step(
        model=model,
        dataloader=test_dataloader,
        loss_fn=loss_fn,
        device=device,
    )

    print(f"Epoch: {epoch} | Train loss: {train_loss:.4f} | Train acc: {train_acc:.4f} | Test loss: {test_loss:.4f} | Test acc: {test_acc:.4f}")

    # Append this epoch's numbers to the matching history lists
    epoch_values = (train_loss, train_acc, test_loss, test_acc)
    for name, value in zip(metric_names, epoch_values):
      results[name].append(value)

  return results

In [53]:
# Number of train/evaluate rounds for this run.
NUM_EPOCHS = 15

# Fine-tune the network and keep the per-epoch metric history.
# `train_dataloader` and `test_dataloader` must have been created by an earlier cell.
model_0_results = train(
    resnet18_model,
    train_dataloader,
    test_dataloader,
    optimizer,
    loss_fn=loss_fn,
    epochs=NUM_EPOCHS,
)


  0%|          | 0/15 [00:00<?, ?it/s]

Epoch: 0 | Train loss: 0.9256 | Train acc: 0.6134 | Test loss: 1.9359 | Test acc: 0.4062
Epoch: 1 | Train loss: 0.7320 | Train acc: 0.7706 | Test loss: 1.9163 | Test acc: 0.5312
Epoch: 2 | Train loss: 0.5658 | Train acc: 0.7800 | Test loss: 1.2438 | Test acc: 0.6771
Epoch: 3 | Train loss: 0.2259 | Train acc: 0.9250 | Test loss: 0.6922 | Test acc: 0.7812
Epoch: 4 | Train loss: 0.1023 | Train acc: 0.9703 | Test loss: 1.7482 | Test acc: 0.6354
Epoch: 5 | Train loss: 0.1832 | Train acc: 0.9453 | Test loss: 0.4907 | Test acc: 0.8750
Epoch: 6 | Train loss: 0.1104 | Train acc: 0.9563 | Test loss: 0.5751 | Test acc: 0.8542
Epoch: 7 | Train loss: 0.1418 | Train acc: 0.9472 | Test loss: 0.6214 | Test acc: 0.7708
Epoch: 8 | Train loss: 0.3904 | Train acc: 0.8962 | Test loss: 0.8450 | Test acc: 0.7396
Epoch: 9 | Train loss: 0.3273 | Train acc: 0.9019 | Test loss: 0.5814 | Test acc: 0.8333
Epoch: 10 | Train loss: 0.3001 | Train acc: 0.8988 | Test loss: 0.5385 | Test acc: 0.8646
Epoch: 11 | Train lo

In [20]:
# Persist only the learned weights (the state_dict). The target folder is created first so
# the save does not fail when ../models does not exist yet.
model_save_path = "../models/ct_scan_model.pth"
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
torch.save(resnet18_model.state_dict(), model_save_path)