In [None]:
# imports
import numpy as np
import torch
import torch.nn as nn
import torchvision

In [None]:
# Mount Google Drive into the Colab filesystem so the image folders
# under "My Drive" can be read with ordinary file paths.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Seed every random number generator the notebook relies on (NumPy, the
# PyTorch CPU generator, and the current CUDA device) with the same value
# so that repeated runs start from the same random state.
SEED = 1234

np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [None]:
# define transforms
# Per-channel statistics passed to Normalize (mean, then standard deviation).
CHANNEL_MEAN = [0.485, 0.456, 0.406]
CHANNEL_STD = [0.229, 0.224, 0.225]

# Preprocessing steps, applied in order: resize the shorter side to 256,
# take the central 224x224 crop, flip horizontally at random, convert to a
# tensor, then normalise each channel.
_pipeline_steps = [
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(CHANNEL_MEAN, CHANNEL_STD),
]
transforms = torchvision.transforms.Compose(_pipeline_steps)

In [None]:
# load datasets
# Validation images must be preprocessed deterministically: same resize,
# crop and normalisation as training, but WITHOUT the random horizontal
# flip. Otherwise the validation metrics are computed on randomly augmented
# images and vary from epoch to epoch for reasons unrelated to the model.
val_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# ImageFolder derives the class label of each image from the name of the
# sub-directory it lives in.
train_dataset = torchvision.datasets.ImageFolder("/content/drive/My Drive/train_images/", transform=transforms)
val_dataset = torchvision.datasets.ImageFolder("/content/drive/My Drive/val_images/", transform=val_transforms)

# Sanity check: number of images found in each split.
print(len(train_dataset))
print(len(val_dataset))

In [None]:
# dataloaders
BATCH_SIZE = 32

# Training batches are reshuffled every epoch so SGD sees the samples in a
# different order each time.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Validation only aggregates metrics, so a fixed order is used: shuffling
# adds work and changes which samples land in the final (possibly smaller)
# batch, which perturbs the per-batch-mean loss between epochs.
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# vgg-19 model
# Load torchvision's VGG-19 with batch normalisation, including its
# pretrained weights.
model = torchvision.models.vgg19_bn(pretrained=True)

# Freeze every pretrained parameter. This has to happen before the new
# classifier is attached so that only the new layers remain trainable.
for pretrained_param in model.parameters():
    pretrained_param.requires_grad = False

# Sizes of the replacement classifier head.
IN_FEATURES = 25088   # input width of the first Linear layer (flattened conv output)
HIDDEN_UNITS = 4096
NUM_CLASSES = 2

# Freshly created layers default to requires_grad=True, so this head is the
# only part of the network the optimizer will update.
head_layers = [
    nn.Linear(IN_FEATURES, HIDDEN_UNITS),
    nn.ReLU(inplace = True),
    nn.Dropout(0.5),
    nn.Linear(HIDDEN_UNITS, HIDDEN_UNITS),
    nn.ReLU(inplace = True),
    nn.Dropout(0.5),
    nn.Linear(HIDDEN_UNITS, NUM_CLASSES),
]
model.classifier = nn.Sequential(*head_layers)

In [None]:
# loss
# Cross-entropy over the raw class scores (logits) produced by the model;
# targets are integer class indices.
criterion = nn.CrossEntropyLoss()

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU,
# and move the model's parameters and buffers onto that device.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = model.to(device=device)

In [None]:
# define optimizer
# SGD hyperparameters, named so they are easy to find and tune.
LEARNING_RATE = 0.005
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4

# All model parameters are handed to the optimizer; parameters that never
# receive a gradient are skipped by SGD when it steps.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
)

In [None]:
# Per-epoch history. Losses are the mean of the per-batch losses; accuracies
# are percentages stored as floats so they can be plotted or compared later.
train_losses = []
train_acc = []
val_losses = []
val_acc = []

# train and validate
for epoch in range(0, 30):

    # ---- train ----
    # Training mode enables the Dropout layers of the classifier head.
    model.train()
    training_loss = 0.0
    total = 0
    correct = 0
    for images, target in train_loader:

        images = images.to(device)
        target = target.to(device)

        # Standard optimisation step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, target)

        loss.backward()
        optimizer.step()

        # Accumulate the batch loss and the number of correct predictions.
        training_loss = training_loss + loss.item()
        _, predicted = output.max(1)  # index of the highest score per sample
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

    training_loss = training_loss / float(len(train_loader))
    training_accuracy = 100.0 * (float(correct) / float(total))
    train_losses.append(training_loss)
    train_acc.append(training_accuracy)

    # ---- validate ----
    # Evaluation mode disables Dropout; no_grad skips gradient bookkeeping
    # for the whole pass since no parameters are updated here.
    model.eval()
    valid_loss = 0.0
    total = 0
    correct = 0
    with torch.no_grad():
        for images, target in val_loader:

            images = images.to(device)
            target = target.to(device)

            output = model(images)
            loss = criterion(output, target)
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

            valid_loss = valid_loss + loss.item()
    valid_loss = valid_loss / float(len(val_loader))
    valid_accuracy = 100.0 * (float(correct) / float(total))
    val_losses.append(valid_loss)
    val_acc.append(valid_accuracy)

    # Epoch summary (str() of the float accuracies prints the same text as before).
    print()
    print("Epoch" + str(epoch) + ":")
    print("Training Accuracy: " + str(training_accuracy) + "    Validation Accuracy: " + str(valid_accuracy))
    print("Training Loss: " + str(training_loss) + "    Validation Loss: " + str(valid_loss))
    print()

In [None]:
import matplotlib.pyplot as plt

# Epoch indices for the x axis. Derived from the recorded history rather
# than a hard-coded count, so x and y always have the same length
# (a fixed range(0, 50) does not match the 30 epochs that are trained and
# makes plt.plot raise).
e = list(range(len(train_losses)))

plt.plot(e, train_losses)
plt.xlabel("Epoch")
plt.ylabel("Training loss")
plt.title("Training loss per epoch")
plt.show()

In [None]:
# Build the x axis from the validation history itself so its length always
# matches the number of recorded epochs, independent of any other cell.
plt.plot(range(len(val_losses)), val_losses)
plt.xlabel("Epoch")
plt.ylabel("Validation loss")
plt.title("Validation loss per epoch")
plt.show()