In [140]:
from torchvision import datasets, models, transforms
from torch.utils.data import random_split, Dataset, DataLoader
from torch.optim import SGD, lr_scheduler
import torch
from torch import nn
from torch import Generator
import random, time, copy

In [117]:
# Load the image dataset rooted at the 'resized' directory (no transform is
# attached here; transforms are applied later per split).
ds = datasets.ImageFolder('resized')
# Last expression of the cell: displays the class names exposed by the dataset.
ds.classes

['AbdomenCT', 'BreastMRI', 'CXR', 'ChestCT', 'Hand', 'HeadCT']

In [118]:
class TransformSubset(Dataset):
    """Dataset view that applies an optional transform to each sample's input.

    Wraps any indexable, sized collection of ``(x, y)`` pairs so that a
    different transform can be attached to each split of a shared dataset.

    Args:
        subset: Indexable object yielding ``(x, y)`` pairs and supporting ``len()``.
        transform: Optional callable applied to ``x`` only; ``y`` is passed
            through untouched. A falsy value (e.g. ``None``) disables it.
    """

    def __init__(self, subset, transform=None):
        self.transform = transform
        self.subset = subset

    def __len__(self):
        # Same number of samples as the wrapped collection.
        return len(self.subset)

    def __getitem__(self, index):
        sample, target = self.subset[index]
        if not self.transform:
            # No transform configured: hand the pair back unchanged.
            return sample, target
        return self.transform(sample), target

def split_ds(ds, val_frac=.1, test_frac=.1, seed=0):
    """Randomly split ``ds`` into train / validation / test subsets.

    The split is reproducible: the same ``seed`` always yields the same
    partition for a dataset of the same length.

    Args:
        ds: Any sized, indexable dataset (must support ``len()`` and indexing).
        val_frac: Fraction of samples assigned to the validation subset.
        test_frac: Fraction of samples assigned to the test subset.
        seed: Seed for the generator that drives the shuffle.

    Returns:
        The ``[train, val, test]`` subsets produced by ``random_split``;
        the train subset receives every sample not assigned to val/test.

    Raises:
        ValueError: If a fraction is negative or the two fractions together
            request more samples than the dataset holds.
    """
    if val_frac < 0 or test_frac < 0:
        raise ValueError("val_frac and test_frac must be non-negative")

    n = len(ds)
    val_n = int(val_frac * n)
    test_n = int(test_frac * n)
    train_n = n - val_n - test_n
    # A negative train size can still sum to n and slip past random_split's
    # own length check, silently producing overlapping subsets.
    if train_n < 0:
        raise ValueError("val_frac + test_frac must not exceed 1")

    # Split the dataset that was passed in, not a freshly re-opened one.
    return random_split(
        ds,
        [train_n, val_n, test_n],
        generator=Generator().manual_seed(seed))

In [119]:
# Partition the dataset into train / validation / test subsets using
# split_ds' default fractions and seed.
train, val, test = split_ds(ds)
# Report how many samples landed in each subset.
print(f'train : {len(train)}\nval : {len(val)}\ntest : {len(test)}')

train : 47164
val : 5895
test : 5895


In [150]:
def train_model(train_loader, val_loader, vgg, criterion, optimizer, scheduler, num_epochs=10):
    """Train ``vgg`` and return it loaded with its best-validation weights.

    Each epoch runs one pass over ``train_loader`` (with parameter updates)
    followed by one pass over ``val_loader`` (no gradients). The weights of
    the epoch with the highest validation accuracy are restored at the end.

    Args:
        train_loader: Sized iterable of ``(inputs, labels)`` training batches.
        val_loader: Sized iterable of ``(inputs, labels)`` validation batches.
        vgg: Model to train; updated in place.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer already bound to the model's parameters.
        scheduler: Epoch-based LR scheduler whose ``step()`` takes no
            arguments, or ``None`` to keep the learning rate fixed.
        num_epochs: Number of train + validation passes to run.

    Returns:
        The same ``vgg`` object, loaded with the best weights observed.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(vgg.state_dict())
    best_acc = 0.0

    train_batches = len(train_loader)
    val_batches = len(val_loader)

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs))
        print('-' * 10)

        # Running totals for this epoch: summed batch losses, number of
        # correct predictions, and number of samples actually seen.
        loss_train, correct_train, seen_train = 0.0, 0, 0
        loss_val, correct_val, seen_val = 0.0, 0, 0

        # ---- training pass ----
        vgg.train()
        for i, (inputs, labels) in enumerate(train_loader):
            if i % 100 == 0:
                print("\rTraining batch {}/{}".format(i, train_batches), end='', flush=True)

            optimizer.zero_grad()
            outputs = vgg(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Predicted class = index of the largest score; detached so the
            # bookkeeping does not touch the autograd graph.
            _, preds = torch.max(outputs.detach(), 1)
            loss_train += loss.item()
            correct_train += (preds == labels).sum().item()
            seen_train += labels.size(0)

        # Advance the learning-rate schedule once per epoch, after the
        # optimizer updates of that epoch.
        if scheduler is not None:
            scheduler.step()

        print()
        # Loss is averaged per batch; accuracy is per sample so that it is a
        # true fraction in [0, 1] regardless of the batch size.
        avg_loss = loss_train / max(train_batches, 1)
        avg_acc = correct_train / max(seen_train, 1)

        # ---- validation pass ----
        vgg.eval()
        with torch.no_grad():  # no graph needed: saves memory and time
            for i, (inputs, labels) in enumerate(val_loader):
                if i % 100 == 0:
                    print("\rValidation batch {}/{}".format(i, val_batches), end='', flush=True)

                outputs = vgg(inputs)
                loss = criterion(outputs, labels)

                _, preds = torch.max(outputs, 1)
                loss_val += loss.item()
                correct_val += (preds == labels).sum().item()
                seen_val += labels.size(0)

        avg_loss_val = loss_val / max(val_batches, 1)
        avg_acc_val = correct_val / max(seen_val, 1)

        print()
        print("Epoch {} result: ".format(epoch))
        print("Avg loss (train): {:.4f}".format(avg_loss))
        print("Avg acc (train): {:.4f}".format(avg_acc))
        print("Avg loss (val): {:.4f}".format(avg_loss_val))
        print("Avg acc (val): {:.4f}".format(avg_acc_val))
        print('-' * 10)
        print()

        # Snapshot the weights whenever validation accuracy improves.
        if avg_acc_val > best_acc:
            best_acc = avg_acc_val
            best_model_wts = copy.deepcopy(vgg.state_dict())

    elapsed_time = time.time() - since
    print()
    print("Training completed in {:.0f}m {:.0f}s".format(elapsed_time // 60, elapsed_time % 60))
    print("Best acc: {:.4f}".format(best_acc))

    vgg.load_state_dict(best_model_wts)
    return vgg

In [151]:
# Build Model

# NOTE(review): newer torchvision releases deprecate `pretrained=True` in
# favour of the `weights=` argument; kept as-is for compatibility with the
# installed version.
vgg16 = models.vgg16(pretrained=True)

# Freeze the convolutional feature extractor so only the classifier is
# trained. The attribute autograd reads is `requires_grad`; assigning to the
# misspelled `require_grad` only creates an unused attribute and freezes nothing.
for param in vgg16.features.parameters():
    param.requires_grad = False

num_features = vgg16.classifier[6].in_features
features = list(vgg16.classifier.children())[:-1] # Remove last layer
features.extend([nn.Linear(num_features, len(ds.classes))]) # Add our layer with one output per dataset class
vgg16.classifier = nn.Sequential(*features) # Replace the model classifier

# Setup training

criterion = nn.CrossEntropyLoss()
# Hand the optimizer only the parameters that are still trainable.
optimizer_ft = SGD(
        [p for p in vgg16.parameters() if p.requires_grad], lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

# Setup loaders

train_loader = DataLoader(
        TransformSubset(train, transform=transforms.Compose([transforms.ToTensor()])), batch_size=8,
        shuffle=True, num_workers=0)
val_loader = DataLoader(
        TransformSubset(val, transform=transforms.Compose([transforms.ToTensor()])), batch_size=8,
        shuffle=False, num_workers=0)

# Keep the returned model (loaded with its best weights) instead of letting the
# notebook echo the full module repr as the cell output.
vgg16 = train_model(train_loader, val_loader, vgg16, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=10)

Epoch 0/10
----------
Training batch 0/5896

KeyboardInterrupt: 