In [57]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader
import copy

In [58]:
# device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# dataloader
def get_data_loaders(train_dir, val_dir, batch_size=32):
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    ])

    train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
    val_dataset = datasets.ImageFolder(root=val_dir, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader

# model
# baseline
class BaselineEfficientNet(nn.Module):
    def __init__(self, num_classes=6):
        super(BaselineEfficientNet, self).__init__()
        self.model = models.efficientnet_b0(pretrained=True)
        for param in self.model.parameters():
            param.requires_grad = False  # freeze
        num_features = self.model.classifier[1].in_features
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.4),
            nn.Linear(num_features, num_classes)
        )

    def forward(self, x):
        return self.model(x)

# class Model(nn.Module):
#     def __init__(self, num_classes=6):
#         super(Model, self).__init__()
#         self.feature_extractor = models.resnet50(pretrained=True)
#         self.feature_extractor = nn.Sequential(*list(self.feature_extractor.children())[:-1])
#         self.classifier = models.efficientnet_b0(pretrained=True)
#         num_features = self.classifier.classifier[1].in_features
#         self.classifier.classifier = nn.Sequential(
#             nn.Dropout(p=0.4),
#             nn.Linear(num_features, num_classes)
#         )

#     def forward(self, x):
#         with torch.no_grad():
#             features = self.feature_extractor(x)
#             features = features.view(features.size(0), -1)
#             outputs = self.classifier(features)
#         return outputs

class Model(nn.Module):
    def __init__(self, num_classes=6):
        super(Model, self).__init__()
        
        # Use ResNet50 as feature extractor
        self.feature_extractor = models.resnet50(pretrained=True)
        
        # # Modify the first layer of ResNet to accept grayscale input if needed
        # self.feature_extractor.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        
        # Remove the last fully connected layer and add a new classifier layer
        num_features = self.feature_extractor.fc.in_features
        self.feature_extractor.fc = nn.Sequential(
            nn.Dropout(p=0.4),
            nn.Linear(num_features, num_classes)
        )

    def forward(self, x):
        # Forward pass through ResNet
        return self.feature_extractor(x)

# train
class Trainer:
    def __init__(self, model, train_loader, val_loader, criterion, optimizer, scheduler, device, patience=5, warmup_steps=0):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.device = device
        self.best_model_wts = copy.deepcopy(model.state_dict())
        self.best_accuracy = 0.0
        self.patience = patience
        self.warmup_steps = warmup_steps
        self.early_stop = False
        self.counter = 0
        self.best_score = None

    def train(self, num_epochs=10):
        for epoch in range(num_epochs):
            self.model.train()
            running_loss = 0.0

            for step, (inputs, labels) in enumerate(self.train_loader):
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                # warm-up steps
                if step < self.warmup_steps:
                    lr_scale = min(1.0, float(step + 1) / self.warmup_steps)
                    for pg in self.optimizer.param_groups:
                        pg['lr'] = lr_scale * self.optimizer.defaults['lr']

                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()

                running_loss += loss.item()

            train_loss = running_loss / len(self.train_loader)
            val_loss, val_accuracy = self.validate()

            print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

            # early stopping
            self._early_stopping(val_loss)
            if self.early_stop:
                print("Early stopping triggered")
                break

            if val_accuracy > self.best_accuracy:
                self.best_accuracy = val_accuracy
                self.best_model_wts = copy.deepcopy(self.model.state_dict())

        self.model.load_state_dict(self.best_model_wts)
        return self.model

    def validate(self):
        self.model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in self.val_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        return val_loss / len(self.val_loader), accuracy

    def _early_stopping(self, val_loss):
        score = -val_loss
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True


In [59]:
# setup
train_loader, val_loader = get_data_loaders('C:/uoft/1517/content/content/lung_ct_augmented/train', 'C:/uoft/1517/content/content/lung_ct_augmented/val')
model = Model(num_classes=5).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=1, eta_min=1e-6)

# train model
trainer = Trainer(model, train_loader, val_loader, criterion, optimizer, scheduler, device, patience=5, warmup_steps=100)

#print out first five pictures of train_loader
trained_model = trainer.train(num_epochs=20)

# save model checkpoint
torch.save(trained_model.state_dict(), 'resnet_efficientnet_finetuned.pth')


# load model
trained_model.load_state_dict(torch.load('resnet_efficientnet_finetuned.pth'))
trained_model.eval()



Epoch 1/20, Train Loss: 0.2642, Val Loss: 0.1353, Val Accuracy: 95.18%
Epoch 2/20, Train Loss: 0.1209, Val Loss: 0.1133, Val Accuracy: 96.09%
Epoch 3/20, Train Loss: 0.0756, Val Loss: 0.1498, Val Accuracy: 94.96%
EarlyStopping counter: 1 out of 5
Epoch 4/20, Train Loss: 0.0517, Val Loss: 0.1267, Val Accuracy: 95.57%
EarlyStopping counter: 2 out of 5
Epoch 5/20, Train Loss: 0.0404, Val Loss: 0.1133, Val Accuracy: 96.41%
Epoch 6/20, Train Loss: 0.0327, Val Loss: 0.1082, Val Accuracy: 96.84%
Epoch 7/20, Train Loss: 0.0291, Val Loss: 0.1110, Val Accuracy: 96.66%
EarlyStopping counter: 1 out of 5
Epoch 8/20, Train Loss: 0.0236, Val Loss: 0.3039, Val Accuracy: 92.18%
EarlyStopping counter: 2 out of 5
Epoch 9/20, Train Loss: 0.0261, Val Loss: 0.1224, Val Accuracy: 96.83%
EarlyStopping counter: 3 out of 5
Epoch 10/20, Train Loss: 0.0223, Val Loss: 0.1078, Val Accuracy: 96.99%
Epoch 11/20, Train Loss: 0.0223, Val Loss: 0.1259, Val Accuracy: 96.47%
EarlyStopping counter: 1 out of 5
Epoch 12/20, 

  trained_model.load_state_dict(torch.load('resnet_efficientnet_finetuned.pth'))


Model(
  (feature_extractor): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
  