In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
from src.data.dataset import load_dataset
from src.models.models import SNeurodCNN

class SNeurodCNN(nn.Module):
    def __init__(self):
        super(SNeurodCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=0)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=0)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=0)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((44, 44))
        self.fc1 = nn.Linear(64 * 44 * 44, 500)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(500, 3)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = F.relu(self.conv2(x))
        x = self.adaptive_pool(F.relu(self.conv3(x)))
        x = x.view(-1, 64 * 44 * 44) 
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

    
model = SNeurodCNN()

print(model)
summary(model, input_size=(1, 180, 180))
PATH = './models/sneurod_cnn.pth'

SNeurodCNN(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (adaptive_pool): AdaptiveAvgPool2d(output_size=(44, 44))
  (fc1): Linear(in_features=123904, out_features=500, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=500, out_features=3, bias=True)
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 178, 178]             320
         MaxPool2d-2           [-1, 32, 89, 89]               0
            Conv2d-3           [-1, 32, 87, 87]           9,248
            Conv2d-4           [-1, 64, 85, 85]          18,496
 AdaptiveAvgPool2d-5           [-1, 64, 44, 44]               0
            Linear-6                  [-1, 

In [2]:
# Optimizer: Adam
# Learning rate: 0.0001
# Epochs: 100
# Batch size: 32
# Regularizers: Early stopping (patience = 5, restore_best_weights = True)



In [3]:
import numpy as np

class EarlyStopping:
    def __init__(self, patience=5, verbose=False, path='./models/sneurod_cnn.pth'):
        self.patience = patience
        self.verbose = verbose
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} to {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


In [4]:
# Load and prepare data

classes = ('CN', 'AD', 'MCI')

id2label = {i: classes[i] for i in range(len(classes))}

label2id = {classes[i]: i for i in range(len(classes))}

trainset, testset, valset = load_dataset(label2id=label2id)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers = 0)

valloader = torch.utils.data.DataLoader(valset, batch_size=32, shuffle=True, num_workers = 0)

# Instanttiate the model and the EarlyStopping
model = SNeurodCNN()
#early_stopping = EarlyStopping(patience=5, verbose=True)
early_stopping = EarlyStopping(patience=5, verbose=True)

# Criterion and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)


In [5]:
# Training Loop with early stopping and progress output
for epoch in range(100): # Number of epochs
    train_loss = 0.0
    model.train()
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs.float())
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # Optionally, print batch progress every few batches

        if i % 100 == 99: # Adjust the modulus value depending on your batch size ana
            print(f'Epoch {epoch + 1}, Batch {i + 1}: Loss: {train_loss / 100:.4f}')
            train_loss = 0.0 # Reset train Loss for the next set of batches

     # Validation phase

    val_loss = 0.0
    model.eval()
    with torch.no_grad():
        for data in valloader:
            inputs, labels = data
            outputs = model(inputs.float())
            loss = criterion(outputs, labels)
            val_loss += loss.item()

    # Print epoch-Level progress
    print(f'End of Epoch {epoch + 1}: Train Loss: {train_loss / len(trainloader): .4f}')

    # Early Stopping check

    early_stopping(val_loss, model)

    if early_stopping.early_stop:
        print("Early stopping triggered")
        break
 
# Load the best saved model
model.load_state_dict

End of Epoch 1: Train Loss:  2.4292
Validation loss decreased (inf to 4.643683). Saving model...


KeyboardInterrupt: 