In [1]:
from utils.annotation_util import annotation_df
from utils.dataset import CustomImageDataset
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import os
import glob
from model.model import CustomAlexNet
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path

# Parameters

In [2]:
batch_size = 10
num_classes = 1
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Create Dataloaders

In [3]:
base_dir = "data/"
train_dir = os.path.join(base_dir, "Training")
val_dir = os.path.join(base_dir, "Validate")
test_dir = os.path.join(base_dir, "Testing")
df_train, df_val, df_test = annotation_df(train_dir, val_dir, test_dir)

In [4]:
torch.manual_seed(17)
data_aug = transforms.Compose([transforms.Resize((224,224))])

train_dataset = CustomImageDataset(df_train, transform = data_aug)
val_dataset = CustomImageDataset(df_val, transform = data_aug)
test_dataset = CustomImageDataset(df_test, transform = data_aug)

train_dataloader = DataLoader(train_dataset, batch_size = batch_size)
val_dataloader = DataLoader(val_dataset, batch_size = batch_size)
test_dataloader = DataLoader(test_dataset, batch_size = batch_size)

# Model

In [5]:
criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
model = CustomAlexNet(num_classes=num_classes, loss_fn=criterion, device=device, threshold=0.5).to(device)

Using cache found in /home/longpingzhang/.cache/torch/hub/pytorch_vision_v0.10.0


In [6]:
model

CustomAlexNet(
  (pretrained_alexnet): AlexNet(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU(inplace=True)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
    (classifier): Sequential(
      (0): Dropout(p=0.5, inpla

# Training

In [7]:
learning_rate = 1e-3
epochs = 10
log_dir_base = os.path.join(os.getcwd(), 'logs')
experiment_name = '1'
ckpts_til_saving = 5
start_training_from_ckpt = None
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate
)

In [8]:
# Create logging directory for saving model weights and summary information
log_dir = os.path.join(log_dir_base, experiment_name)
Path(log_dir).mkdir(parents=True, exist_ok=True)

model_weights_dir = os.path.join(log_dir, "checkpoints")
Path(model_weights_dir).mkdir(parents=True, exist_ok=True)

summary_dir = os.path.join(log_dir, "summary")
Path(summary_dir).mkdir(parents=True, exist_ok=True)

# Start training from a specific checkpoint
if start_training_from_ckpt:
    model = model.load_state(torch.load(start_training_from_ckpt)['model_state_dict'])

In [9]:
writer = SummaryWriter(log_dir=summary_dir ,flush_secs=20)

best_loss = float("inf")
train_losses = []
val_losses = []
epoch_num = []


for epoch in range(epochs):
    if epoch == 0:
        torch.save({'model_state_dict': model.state_dict()},
                    f'{model_weights_dir}/epoch{epoch}_before_training.pt')
    
    # Training step
    iteration = 0
    epoch_train_loss_it_cum = 0

    model.train()
    
    for batch in train_dataloader:
        optimizer.zero_grad()
        train_loss = model.training_step(batch)
        train_loss.backward()
        optimizer.step()
        epoch_train_loss_it_cum += train_loss.item()

        iteration += 1         
    epoch_train_loss = epoch_train_loss_it_cum / iteration
    train_losses.append(epoch_train_loss)
    epoch_num.append(epoch)

    # Validation step
    with torch.no_grad():
        model.eval()
        val_loss, cf_matrix = model.validation_step(val_dataloader)
        val_losses.append(val_loss.item())
        
        model.train()
    
    # Write to logs for tensorboard visualization
    writer.add_scalars('alexnet', {'training_loss': epoch_train_loss,
                                'validation_loss': val_loss}, epoch)
    
    # Save the model weights every ckpts_til_saving
    if epochs % ckpts_til_saving == 0:
        torch.save({'model_state_dict': model.state_dict()},
                   f'{model_weights_dir}/epoch{epoch}.pt')
    
    # Save the best model
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save({'model_state_dict': model.state_dict()},
                   f'{model_weights_dir}/best_model.pt')
        
    # Calculate accuracy, sensitivity, and specificity over validation set
    tn, fp, fn, tp = cf_matrix.ravel()
    accuracy = (tp + tn) / (tp + tn + fp + fn)
#     sensitivity = tp / (tp + fp)
#     specificity = tn / (tn + fn)
    
    print(f'Epcoh: {epoch}, training_loss: {epoch_train_loss}, validation_loss: {val_loss}, \
              accuracy: {accuracy}, sensitivity: {sensitivity}, specificity: {specificity} ')

  specificity = tn / (tn + fn)


Epcoh: 0, training_loss: 1459.696004603676, validation_loss: 25374.98828125,               accuracy: 0.5, sensitivity: 0.5, specificity: nan 


  specificity = tn / (tn + fn)


Epcoh: 1, training_loss: 3053.5789935772236, validation_loss: 51465.4375,               accuracy: 0.5, sensitivity: 0.5, specificity: nan 


  specificity = tn / (tn + fn)


Epcoh: 2, training_loss: 5134.674697305148, validation_loss: 144879.65625,               accuracy: 0.5, sensitivity: 0.5, specificity: nan 


  specificity = tn / (tn + fn)


Epcoh: 3, training_loss: 9636.578275042686, validation_loss: 4578.7431640625,               accuracy: 0.5, sensitivity: 0.5, specificity: nan 


  specificity = tn / (tn + fn)


Epcoh: 4, training_loss: 1151.5689180542786, validation_loss: 5734.6337890625,               accuracy: 0.5, sensitivity: 0.5, specificity: nan 


  specificity = tn / (tn + fn)


Epcoh: 5, training_loss: 388.73219181941107, validation_loss: 73275.8671875,               accuracy: 0.5, sensitivity: 0.5, specificity: nan 


  specificity = tn / (tn + fn)


Epcoh: 6, training_loss: 4180.958160400391, validation_loss: 10140.6728515625,               accuracy: 0.5, sensitivity: 0.5, specificity: nan 


  specificity = tn / (tn + fn)


Epcoh: 7, training_loss: 659.8468995461097, validation_loss: 1546.4310302734375,               accuracy: 0.5, sensitivity: 0.5, specificity: nan 


  specificity = tn / (tn + fn)


Epcoh: 8, training_loss: 1079.8160876952684, validation_loss: 1857.92041015625,               accuracy: 0.5, sensitivity: 0.5, specificity: nan 
Epcoh: 9, training_loss: 166.6310195189256, validation_loss: 431.7682800292969,               accuracy: 0.5, sensitivity: 0.5, specificity: nan 


  specificity = tn / (tn + fn)
