Practical Work in AI - Concept Models
Tragler Thomas
====================


In [1]:
import sys

import pandas as pd
import os
import torch
import torch.nn as nn
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import wandb
from datetime import datetime
import torch.optim as optim
from sklearn.model_selection import KFold

import derm7pt_data
from derm7pt_data import Derm7pt_data

from importlib import reload

reload(derm7pt_data)

<module 'derm7pt_data' from 'D:\\Business\\Uni\\Practical Work\\PW_ConceptModels\\derm7pt_data.py'>

In [2]:
from derm7pt_data import Derm7pt_data

from importlib import reload
reload(derm7pt_data)

#Data loading
path = os.path.normpath('Data\\Derm7pt')

derm7pt = Derm7pt_data(path)
metadata = derm7pt.metadata
print(metadata.shape)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("device:", device)

(1011, 34)
device: cuda


In [3]:
metadata.columns

Index(['case_num', 'diagnosis', 'seven_point_score', 'pigment_network',
       'streaks', 'pigmentation', 'regression_structures', 'dots_and_globules',
       'blue_whitish_veil', 'vascular_structures',
       'level_of_diagnostic_difficulty', 'elevation', 'location', 'sex',
       'clinic', 'derm', 'diagnosis_num', 'is_cancer', 'abbrevs', 'info',
       'pigment_network_num', 'pigment_network_score', 'streaks_num',
       'streaks_score', 'pigmentation_num', 'pigmentation_score',
       'regression_structures_num', 'regression_structures_score',
       'dots_and_globules_num', 'dots_and_globules_score',
       'blue_whitish_veil_num', 'blue_whitish_veil_score',
       'vascular_structures_num', 'vascular_structures_score'],
      dtype='object')

In [4]:
#Torch CNN model with 3 Conv layers and 3 fully connected layers
class Net(nn.Module):
    def __init__(self, num_classes=1, num_concepts=1, image_size=(192, 128)):
        super(Net, self).__init__()
        #Size reduction factor of image by pooling layers
        mod = 1
        
        #conv Layers
        in_channels, out_channels = (3, 16)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=4, stride=4, padding=0)
        mod *= 4
        
        in_channels, out_channels = (out_channels, 2*out_channels)
        self.conv2 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        mod *= 2
        
        in_channels, out_channels = (out_channels, out_channels)
        self.conv3 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        mod *= 2
        
        # Fully connected layers
        self.first_linear_layer_size = out_channels * (image_size[0]//mod * image_size[1]//mod)
        self.fc1 = nn.Linear(self.first_linear_layer_size, 256)  
        self.fc2 = nn.Linear(256, 64)
        self.fc_concepts = nn.Linear(64, num_concepts)  
        self.fc_outputs = nn.Linear(num_concepts, num_classes)  
        
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Convolutional layers with activation and pooling
        x = self.relu(self.conv1(x))
        x = self.pool1(x)
        x = self.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.relu(self.conv3(x))
        x = self.pool3(x)

        # Flatten the output for fully connected layers
        x = x.view(-1, self.first_linear_layer_size)  # Corrected input size based on spatial dimensions

        # Fully connected layers with activation
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x_concepts = self.sigmoid(self.fc_concepts(x))
        x_outputs = self.softmax(self.fc_outputs(x_concepts))

        return x_concepts, x_outputs
    
    
def majority_class_baseline(val_idx):
    print("start validation baseline: ", datetime.now())
    majority_loader = DataLoader(
        dataset=derm7pt,
        batch_size=999999,
        sampler=torch.utils.data.SubsetRandomSampler(val_idx),
    )
    for i, batch in enumerate(majority_loader, 0):
        inputs, labels, concept_labels = batch
        baseline, simple_val_baseline = majority_class_accuracy_by_labels(labels)
        
        #concept baseline
        concept_baseline = 0
        concept_outputs = torch.zeros(len(labels), num_concepts)
        concept_val_baseline = ((concept_outputs == concept_labels).sum().item()) / (len(labels)*num_concepts)
        
        print("end validation baseline:   ", datetime.now(), ", baseline: ", baseline, " percent ",  simple_val_baseline, " concept_baseline: ", concept_baseline, " concept_val_baseline: ", concept_val_baseline)
        return simple_val_baseline, concept_val_baseline
    
def majority_class_accuracy_by_labels(true_labels):
    # Find the most frequent class in the training set
    elems, counts = true_labels.unique(return_counts=True)
    majority_count = counts[counts.argmax()]
    majority_class = elems[counts.argmax()]
    #predictions = torch.full_like(true_labels, majority_class)
    accuracy = majority_count / len(true_labels)
    return majority_class, accuracy

In [7]:
#Training the model
# hyperparameters
n_epochs = 30
learning_rate = 0.01
n_folds = 8
batch_size = 64

num_classes = derm7pt.diagnosis_mapping[derm7pt.model_columns["label"]].nunique()
num_concepts = len(derm7pt.concepts_mapping)
criterion_concept = nn.BCELoss()
criterion = nn.CrossEntropyLoss() #Categorical crossEntropyLoss
print('Before Init')

wandb.init(
    # set the wandb project where this run will be logged
    project= "PracticalWork",

    # track hyperparameters and run metadata
    config={
    "learning_rate": learning_rate,
    "architecture": "SimpleCNN",
    "dataset": "derm7pt",
    "labels": derm7pt.model_columns["label"],
    "epochs": n_epochs,
    "batch_size": batch_size,
    "n_folds": n_folds,
    "device": device,
    "num_classes": num_classes,
    "num_concepts": num_concepts,
    },
    name="run"+str(datetime.now())
)



Before Init


In [8]:
# Training loop
kf = KFold(n_splits=n_folds, shuffle=True)
for fold, (train_idx, val_idx) in enumerate(kf.split(derm7pt.metadata)):
    #get the majority class of the validation set
    simple_val_baseline, concept_val_baseline = majority_class_baseline(val_idx)      
    
    train_loader = DataLoader(
        dataset=derm7pt,
        batch_size=batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(train_idx),
    )
    val_loader = DataLoader(
        dataset=derm7pt,
        batch_size=batch_size,
        sampler=torch.utils.data.SubsetRandomSampler(val_idx),
    )
    
    # Instantiate the model
    model = Net(num_classes=num_classes,num_concepts=num_concepts, image_size=derm7pt.image_size)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    for epoch in range(n_epochs):
        running_loss = 0.0
        running_loss_concepts = 0.0
        i = 0
        model.train()
        for i, batch in enumerate(train_loader, 0):
            inputs, labels, concept_labels = batch
            #one hot encoding of the label
            labels = torch.eye(num_classes)[labels.squeeze().int()]
            inputs, labels, concept_labels = inputs.to(device), labels.to(device), concept_labels.to(device)
    
            # Zero the parameter gradients
            optimizer.zero_grad()
            # loss for both concepts and outputs
            concept_outputs, outputs = model(inputs)
            loss_concepts = criterion_concept(concept_outputs, concept_labels)
            loss_concepts.backward(retain_graph=True)
            loss_outputs = criterion(outputs, labels)
            loss_outputs.backward()
            optimizer.step()
            
            # Print statistics
            running_loss += loss_outputs.item()
            running_loss_concepts += loss_concepts.item()
            
        running_loss /= (i+1)    
        running_loss_concepts /= (i+1)
        
        # Validation
        model.eval()
        correct = 0
        concept_correct = 0
        total = 0
        with torch.no_grad():
            for i, batch in enumerate(val_loader, 0):
                inputs, labels, concept_labels = batch
                inputs, labels, concept_labels = inputs.to(device), labels.to(device), concept_labels.to(device)
                concept_outputs, outputs = model(inputs)
                outputs = outputs.argmax(dim=1)
                concept_outputs = concept_outputs.round()
                total = labels.size(0)
                correct = (outputs == labels).sum().item()
                concept_correct = (concept_outputs == concept_labels).sum().item()
                
        val_accuracy = correct/total
        concept_val_accuracy = concept_correct/(total*num_concepts)
        wandb.log({"loss": running_loss, "concept_loss:": running_loss_concepts, "validation_accuracy": val_accuracy, "concept_val_accuracy": concept_val_accuracy})
        print('[%d, %5d] loss: %.4f, val_accuracy: %.4f, simple_baseline: %.4f, concept_0_baseline: %.4f' % (epoch + 1, i + 1, running_loss, val_accuracy, simple_val_baseline, concept_val_baseline))   
        
    #ToDo only one fold for now
    break

print('Finished Training')
wandb.finish()
print('Finished Training')

start validation baseline:  2024-04-20 20:04:44.597776
end validation baseline:    2024-04-20 20:04:45.922781 , baseline:  tensor(1)  percent  tensor(0.5276)  concept_baseline:  0  concept_val_baseline:  0.7322834645669292
tensor(0.5276) 0.7322834645669292
[1,     2] loss: 1.5940, val_accuracy: 0.2698, simple_baseline: 0.5276, concept_0_baseline: 0.7323
[2,     2] loss: 1.5310, val_accuracy: 0.3810, simple_baseline: 0.5276, concept_0_baseline: 0.7323
[3,     2] loss: 1.4577, val_accuracy: 0.5397, simple_baseline: 0.5276, concept_0_baseline: 0.7323
[4,     2] loss: 1.3964, val_accuracy: 0.5397, simple_baseline: 0.5276, concept_0_baseline: 0.7323
[5,     2] loss: 1.3757, val_accuracy: 0.5397, simple_baseline: 0.5276, concept_0_baseline: 0.7323
[6,     2] loss: 1.3591, val_accuracy: 0.4921, simple_baseline: 0.5276, concept_0_baseline: 0.7323
[7,     2] loss: 1.3522, val_accuracy: 0.5556, simple_baseline: 0.5276, concept_0_baseline: 0.7323
[8,     2] loss: 1.3506, val_accuracy: 0.4762, sim

VBox(children=(Label(value='0.001 MB of 0.008 MB uploaded\r'), FloatProgress(value=0.15375969427897043, max=1.…

0,1
concept_loss:,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
concept_val_accuracy,▆▂▅▄▃▃▇▂▃▇▄▅▄▆▁▃▅▃▆▅▃▄▃▄▃█▃▇▅▅
loss,█▆▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation_accuracy,▁▃▆▆▆▅▇▅▆▆▆▇▆▇▅▆▇▅▆▆▆▅▇▇▅█▇▇▇▆

0,1
concept_loss:,0.52553
concept_val_accuracy,0.73016
loss,1.33212
validation_accuracy,0.50794


Finished Training
