In [1]:
# header files needed
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision

In [None]:
# Mount Google Drive inside the Colab VM; the ImageFolder datasets read their
# images from "/content/drive/My Drive/...".
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Seed NumPy, the PyTorch CPU generator and the current CUDA device with one
# fixed value (in this order) so runs are repeatable.
# NOTE(review): torch.manual_seed already seeds CUDA as well; cuDNN kernels may
# still be non-deterministic unless torch.backends.cudnn.deterministic is set.
for _seed_rng in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(1234)

In [None]:
# define transforms
train_transforms = torchvision.transforms.Compose([torchvision.transforms.RandomRotation(30),
                                       torchvision.transforms.Resize((224, 224)),
                                       torchvision.transforms.RandomHorizontalFlip(),
                                       torchvision.transforms.ToTensor(),
                                       torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

In [None]:
# Validation images must be preprocessed deterministically. Reusing the
# augmenting train_transforms (random rotation / flip) would score every
# validation pass on a different random view of the data, making validation
# accuracy noisy and not comparable across epochs. Resize and normalisation
# match the training pipeline exactly.
val_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

train_data = torchvision.datasets.ImageFolder("/content/drive/My Drive/train_images/", transform=train_transforms)
val_data = torchvision.datasets.ImageFolder("/content/drive/My Drive/val_images/", transform=val_transforms)

# ImageFolder assigns label indices from the sorted sub-folder names of each
# root separately; if the two splits do not contain the same class folders,
# the same index would mean different classes in train and val.
assert train_data.class_to_idx == val_data.class_to_idx, (
    "train/val class folders differ; labels would be misaligned"
)

print(len(train_data))
print(len(val_data))

In [None]:
# Colab VMs typically expose only a few CPUs. Requesting more DataLoader
# workers than cores makes PyTorch warn and oversubscribes the machine, so the
# requested 16 workers are capped at the available CPU count.
NUM_WORKERS = min(16, os.cpu_count() or 1)
# Pinned host memory only speeds up host->GPU copies; skip it on CPU runtimes.
PIN_MEMORY = torch.cuda.is_available()

train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=32, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, shuffle=True
)
# Sample order does not change validation loss/accuracy, so keep it fixed.
val_loader = torch.utils.data.DataLoader(
    val_data, batch_size=32, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, shuffle=False
)
print(len(train_loader))
print(len(val_loader))

In [None]:
# define network
class DenseNet_121(torch.nn.Module):
    """DenseNet-121-style image classifier with four dense stages of 6/12/24/16 layers.

    Each dense layer (built by ``dense_block``) produces ``GROWTH_RATE`` new
    feature maps, which are concatenated onto the running feature tensor. The
    transition block between consecutive stages has two steps:
      * a 1x1 conv that halves the channel count;
      * a 2x2 average pool that halves the spatial size.

    Sub-modules are registered as ``features``, ``dense_block_<stage>_<layer>``
    (both 1-based), ``transition_block_<stage>``, ``avgpool`` and
    ``classifier``. State-dict keys therefore look like
    ``dense_block_3_24.0.weight``.

    NOTE(review): this deviates from the reference DenseNet-121 (e.g.
    torchvision.models.densenet121) in four ways:
      * dense layers use Conv-BN-ReLU instead of pre-activation BN-ReLU-Conv;
      * transitions have no BN/ReLU;
      * there is no final BatchNorm;
      * the head keeps a 7x7 map feeding a (1024*7*7 -> num_classes) Linear
        instead of global average pooling.
    Changing any of these would invalidate existing checkpoints, so they are
    left unchanged.

    Args:
        num_classes: number of output logits.
    """

    # Number of dense layers in each of the four stages.
    BLOCK_CONFIG = (6, 12, 24, 16)
    # Feature maps each dense layer appends to the concatenated tensor.
    GROWTH_RATE = 32
    # Output width of the 1x1 bottleneck conv inside every dense layer.
    BOTTLENECK_WIDTH = 128
    # Channels produced by the stem (``self.features``).
    NUM_INIT_FEATURES = 64

    # define dense block
    def dense_block(self, input_channels):
        """Build one bottleneck dense layer: 1x1 conv-BN-ReLU then 3x3 conv-BN-ReLU.

        Args:
            input_channels: channels of the concatenated input tensor.

        Returns:
            An ``nn.Sequential`` mapping ``input_channels`` to ``GROWTH_RATE``
            channels with unchanged spatial size (3x3 conv uses padding=1).
        """
        return torch.nn.Sequential(
            torch.nn.Conv2d(input_channels, self.BOTTLENECK_WIDTH, kernel_size=1, bias=False),
            torch.nn.BatchNorm2d(self.BOTTLENECK_WIDTH),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(self.BOTTLENECK_WIDTH, self.GROWTH_RATE, kernel_size=3, padding=1, bias=False),
            torch.nn.BatchNorm2d(self.GROWTH_RATE),
            torch.nn.ReLU(inplace=True),
        )

    # init function
    def __init__(self, num_classes = 2):
        super(DenseNet_121, self).__init__()

        # Stem: 7x7 stride-2 conv + 3x3 stride-2 max-pool (224x224 -> 56x56 x 64).
        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(3, self.NUM_INIT_FEATURES, kernel_size=7, stride=2, padding=3, bias=False),
            torch.nn.BatchNorm2d(self.NUM_INIT_FEATURES),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        # Dense layers and transitions are registered with the same names and
        # in the same order as the original hand-unrolled definitions. This
        # keeps parameter order, state_dict keys and seeded weight
        # initialisation identical.
        # Channel progression for a 224x224 input:
        #   stage 1 @ 56x56:  64 -> 256, transition -> 128 @ 28x28
        #   stage 2 @ 28x28: 128 -> 512, transition -> 256 @ 14x14
        #   stage 3 @ 14x14: 256 -> 1024, transition -> 512 @ 7x7
        #   stage 4 @ 7x7:   512 -> 1024
        channels = self.NUM_INIT_FEATURES
        num_stages = len(self.BLOCK_CONFIG)
        for stage, num_layers in enumerate(self.BLOCK_CONFIG, start=1):
            for layer in range(1, num_layers + 1):
                setattr(self, f"dense_block_{stage}_{layer}", self.dense_block(channels))
                channels += self.GROWTH_RATE
            # No transition after the final stage.
            if stage < num_stages:
                setattr(self, f"transition_block_{stage}", torch.nn.Sequential(
                    torch.nn.Conv2d(channels, channels // 2, kernel_size=1, bias=False),
                    torch.nn.AvgPool2d(kernel_size=2, stride=2),
                ))
                channels //= 2

        # Pool to a fixed 7x7 grid (identity for 224x224 input) so the Linear
        # layer's size does not depend on the input resolution.
        pooled_size = 7
        self.avgpool = torch.nn.AdaptiveAvgPool2d(pooled_size)

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(channels * pooled_size * pooled_size, num_classes)
        )

    def forward(self, x):
        """Compute class logits.

        Args:
            x: image batch with 3 channels (the stem conv expects 3 input
               channels); presumably (N, 3, H, W) normalised like the
               training transforms -- TODO confirm for other callers.

        Returns:
            Tensor of shape (N, num_classes) with unnormalised logits.
        """
        x = self.features(x)

        num_stages = len(self.BLOCK_CONFIG)
        for stage, num_layers in enumerate(self.BLOCK_CONFIG, start=1):
            for layer in range(1, num_layers + 1):
                new_features = getattr(self, f"dense_block_{stage}_{layer}")(x)
                # Dense connectivity: every layer sees all earlier feature maps.
                x = torch.cat([x, new_features], 1)
            if stage < num_stages:
                x = getattr(self, f"transition_block_{stage}")(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

In [None]:
# Instantiate the network and move its parameters/buffers to the GPU when one
# is available, otherwise keep everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = DenseNet_121().to(device)
model

In [None]:
# Cross-entropy on raw logits (the classifier ends in a plain Linear layer, no
# softmax); with default arguments it averages the loss over the batch.
criterion = nn.CrossEntropyLoss()

In [None]:
# SGD with momentum; weight_decay adds L2 regularisation to every parameter
# returned by model.parameters() (BatchNorm affine weights included).
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=3e-3,
    momentum=0.9,
    weight_decay=1e-3,
)

In [None]:
# Run configuration.
NUM_EPOCHS = 100
# Only epochs >= MIN_CHECKPOINT_EPOCH may be checkpointed, so an early lucky
# validation score cannot become the saved "best" model.
MIN_CHECKPOINT_EPOCH = 30
# NOTE(review): a relative path is written to the Colab VM's local disk, which
# is wiped when the runtime resets; consider a path under the mounted Drive.
BEST_MODEL_PATH = "best_model.pth"

# Per-epoch history. Losses and accuracies (percent) are stored as floats so
# they can be plotted or compared directly.
train_losses = []
train_acc = []
val_losses = []
val_acc = []
best_metric = -1
best_metric_epoch = -1


def _run_epoch(loader, train):
    """Make one pass over ``loader`` with the global ``model``.

    When ``train`` is True the model is put in training mode and updated with
    the global ``optimizer`` after every batch; otherwise it runs in eval mode
    with gradients disabled.

    Returns:
        (mean_loss, accuracy_percent). The loss is averaged per sample (each
        batch's mean loss weighted by its size), so a smaller final batch does
        not skew the epoch mean.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train(train)
    loss_sum = 0.0
    total = 0
    correct = 0
    with torch.set_grad_enabled(train):
        for inputs, targets in loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            if train:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            if train:
                loss.backward()
                optimizer.step()

            batch_size = targets.size(0)
            # criterion uses the default 'mean' reduction; scale back to a sum.
            loss_sum += loss.item() * batch_size
            total += batch_size
            correct += outputs.argmax(1).eq(targets).sum().item()

    if total == 0:
        raise ValueError("loader yielded no samples")
    return loss_sum / total, 100.0 * correct / total


# train and validate
for epoch in range(NUM_EPOCHS):
    training_loss, training_accuracy = _run_epoch(train_loader, train=True)
    train_losses.append(training_loss)
    train_acc.append(training_accuracy)

    valid_loss, valid_accuracy = _run_epoch(val_loader, train=False)
    val_losses.append(valid_loss)
    val_acc.append(valid_accuracy)

    # Checkpoint whenever validation accuracy improves after the warm-up window.
    if epoch >= MIN_CHECKPOINT_EPOCH and valid_accuracy > best_metric:
        best_metric = valid_accuracy
        best_metric_epoch = epoch
        torch.save(model.state_dict(), BEST_MODEL_PATH)

    print()
    print(f"Epoch {epoch}:")
    print(f"Training Accuracy: {training_accuracy}    Validation Accuracy: {valid_accuracy}")
    print(f"Training Loss: {training_loss}    Validation Loss: {valid_loss}")
    print()