In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import torchvision
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
from PIL import Image
Image.LOAD_TRUNCATED_IMAGES = True
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from tqdm.notebook import tqdm
import os
import zipfile
from torchvision import transforms, models, datasets

In [None]:
import os
import zipfile

In [None]:
# Mount Google Drive at /content/drive so the dataset archive stored under
# /content/drive/MyDrive can be read by the extraction cell below.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Pick the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [None]:
if True:  # manual toggle: set to False to skip re-extracting the archive
  # The context manager closes the archive handle even if extraction fails.
  with zipfile.ZipFile('/content/drive/MyDrive/FGData.zip', 'r') as zip_ref:
    # No target directory is given, so members land in the current working directory.
    zip_ref.extractall()


In [None]:
# for item in os.listdir('/content/drive/MyDrive/FGData.zip'):
#     if item.endswith('.zip'):
#         zip_ref = zipfile.ZipFile("./trial/" + item, 'r')
#         print('Extracting ' + item + ' ...')
#         zip_ref.extractall('trial')
#         zip_ref.close()
#         # os.remove("./trial/" + item)
# print("All files extracted")

In [None]:
# Channel count of the feature map produced by the truncated ResNet-50 trunk;
# it fixes the input width of the classifier (features ** 2).
features = 2048
# Spatial side of that feature map for 224x224 inputs. forward() now reads the
# spatial size from the tensor itself; the constant is kept so that any other
# cell referring to it keeps working.
fmap_size = 7

class BCNN(nn.Module):
    """Bilinear CNN: ResNet-50 trunk, bilinear pooling, 2-way linear classifier.

    The convolutional feature map is flattened over its spatial positions,
    multiplied with its own transpose (averaged over positions) to obtain a
    channel-by-channel matrix, passed through an element-wise square root and
    L2 normalisation, and finally classified by a single linear layer.

    Args:
        fine_tune: when False (default) every backbone weight is frozen and
            only the linear head is trainable; when True the backbone is
            trained as well.

    The module is created on the default (CPU) device; move it with
    ``.to(device)`` / ``.cuda()`` like any other ``nn.Module``.
    """

    def __init__(self, fine_tune=False):

        super(BCNN, self).__init__()

        resnet = models.resnet50(pretrained=True)

        # Freeze (default) or unfreeze the whole backbone in a single pass.
        trainable = bool(fine_tune)
        for param in resnet.parameters():
            param.requires_grad = trainable

        # Keep everything except the last two child modules, so the trunk
        # returns a spatial feature map instead of class scores.
        layers = list(resnet.children())[:-2]
        self.features = nn.Sequential(*layers)

        # No hard-coded .cuda() here: the caller moves the whole model, which
        # keeps all sub-modules on one device and lets the class work on CPU.
        self.fc = nn.Linear(features ** 2, 2)
        self.dropout = nn.Dropout(0.5)

        # Initialize the fc layer.
        nn.init.xavier_normal_(self.fc.weight.data)

        if self.fc.bias is not None:
            torch.nn.init.constant_(self.fc.bias.data, val=0)

    def forward(self, x):
        """Return class logits of shape (batch, 2) for an image batch ``x``.

        ``x`` is whatever ``self.features`` accepts (a 4-D image batch); the
        channel count and spatial size of the resulting feature map are read
        from the tensor, so the input resolution is not restricted to the one
        implied by ``fmap_size``.
        """
        batch_size = x.size(0)

        # (batch, C, H, W) feature map; C == 2048 and H == W == 7 for
        # 224x224 inputs to the ResNet-50 trunk.
        x = self.features(x)
        channels = x.size(1)
        locations = x.size(2) * x.size(3)

        # Flatten the spatial grid: (batch, C, H*W).
        x = x.reshape(batch_size, channels, locations)
        x = self.dropout(x)

        # Bilinear pooling: (batch, C, H*W) @ (batch, H*W, C) -> (batch, C, C),
        # averaged over the spatial positions.
        x = torch.bmm(x, torch.transpose(x, 1, 2)) / locations
        x = x.reshape(batch_size, channels ** 2)

        # Element-wise square root (epsilon keeps the gradient finite at 0)
        # followed by per-sample L2 normalisation.
        x = torch.sqrt(x + 1e-5)
        x = F.normalize(x)

        x = self.dropout(x)
        x = self.fc(x)

        return x

In [None]:
def save_checkpoint(model, optimizer, train_loss, test_loss, train_acc, valid_acc, epoch, batch_idx, checkpoint_path):
    """Serialise the current training state to ``checkpoint_path``.

    The saved object is a dict holding the model and optimizer state dicts
    together with the loss/accuracy values, epoch and batch index passed in.
    Nothing is returned.
    """
    snapshot = dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        train_loss=train_loss,
        test_loss=test_loss,
        train_acc=train_acc,
        valid_acc=valid_acc,
        epoch=epoch,
        batch_idx=batch_idx,
    )
    torch.save(snapshot, checkpoint_path)

In [None]:
def train(model, criterion, optimizer, train_loader, val_loader, save_location, checkpoint_location, early_stop=4, n_epochs=10, print_every=1):
    """Train ``model`` with per-epoch checkpoints and validation-based early stopping.

    Args:
        model: network to optimise; batches are moved to the device of its
            first parameter (CPU if it has none).
        criterion: loss called as ``criterion(output, label)``.
        optimizer: optimizer stepping the model's parameters.
        train_loader: iterable of ``(data, label)`` training batches.
        val_loader: iterable of validation batches, or ``None``. Without a
            validation loader no best-model file is written and early
            stopping is disabled, so all ``n_epochs`` are run.
        save_location: path receiving ``{'state_dict': ...}`` every time the
            validation accuracy improves.
        checkpoint_location: path handed to ``save_checkpoint`` after every
            epoch, together with the running loss/accuracy histories.
        early_stop: number of consecutive epochs without a validation
            accuracy improvement after which training stops.
        n_epochs: maximum number of epochs.
        print_every: print the epoch summary every this many epochs. It only
            affects printing; best-model tracking runs every epoch.

    Returns:
        ``model``, with ``model.epochs`` set to the number of epochs run and
        ``model.optimizer`` set to ``optimizer``. After an early stop the best
        saved weights are loaded back; after a full run the model keeps the
        weights of the last epoch.
    """
    # Follow the model's own device instead of assuming CUDA is available.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    valid_acc_max = 0
    best_epoch = None  # stays None until some epoch improves on valid_acc_max
    stop_count = 0
    model.epochs = 0

    train_acc_list = []
    val_acc_list = []

    train_loss_list = []
    val_loss_list = []

    for epoch in tqdm(range(n_epochs)):

        train_loss = 0
        valid_loss = 0

        train_acc = 0
        valid_acc = 0

        train_seen = 0

        model.train()

        batch_num = 0

        for data, label in train_loader:
            batch_num += 1
            data, label = data.to(device), label.to(device)
            output = model(data)

            loss = criterion(output, label)
            optimizer.zero_grad()

            loss.backward()
            optimizer.step()

            # criterion returns a batch average, so weight it by the batch size.
            batch_size = data.size(0)
            train_loss += loss.item() * batch_size

            # Index of the largest logit is the predicted class; count the hits
            # on the model's device instead of round-tripping through the host.
            _, pred = torch.max(output, dim=1)
            train_acc += pred.eq(label.view_as(pred)).sum().item()
            train_seen += batch_size

            if batch_num % 10 == 0:
                # batch_num already counts the batch that just finished.
                print(f'Epoch: {epoch}\t{100 * batch_num / len(train_loader):.2f}% complete.')

        model.epochs += 1

        if val_loader is not None:
            valid_seen = 0
            model.eval()

            with torch.no_grad():
                for data, label in val_loader:

                    data, label = data.to(device), label.to(device)
                    output = model(data)
                    loss = criterion(output, label)

                    batch_size = data.size(0)
                    valid_loss += loss.item() * batch_size

                    _, pred = torch.max(output, dim=1)
                    valid_acc += pred.eq(label.view_as(pred)).sum().item()
                    valid_seen += batch_size

            # max(..., 1) guards against an empty loader.
            valid_loss = valid_loss / max(valid_seen, 1)
            valid_acc = valid_acc / max(valid_seen, 1)

        train_loss = train_loss / max(train_seen, 1)
        train_acc = train_acc / max(train_seen, 1)

        train_acc_list.append(train_acc)
        train_loss_list.append(train_loss)
        val_acc_list.append(valid_acc)
        val_loss_list.append(valid_loss)

        save_checkpoint(model, optimizer, train_loss_list, val_loss_list, train_acc_list, val_acc_list, epoch, batch_num, checkpoint_location)

        if (epoch + 1) % print_every == 0:

            print(f'\nEpoch: {epoch} \tTraining Loss: {train_loss:.4f} \tValidation Loss: {valid_loss:.4f}')
            print(f'\t\tTraining Accuracy: {100 * train_acc:.2f}%\t Validation Accuracy: {100 * valid_acc:.2f}%')

        # Best-model tracking needs a validation signal; without one there is
        # nothing to compare, so just keep training.
        if val_loader is None:
            continue

        if valid_acc > valid_acc_max:

            torch.save({
                'state_dict': model.state_dict()
            }, save_location)

            stop_count = 0
            valid_acc_max = valid_acc
            best_epoch = epoch

        else:

            stop_count += 1

            if stop_count >= early_stop:

                print(f'\nEarly Stopping Total epochs: {epoch}. Best epoch: {best_epoch} with best val acc: {100 * valid_acc_max:.2f}%')
                # Only restore if a best model was actually written in this run.
                if best_epoch is not None:
                    model.load_state_dict(torch.load(save_location, map_location=device)['state_dict'])
                model.optimizer = optimizer
                return model

    model.optimizer = optimizer

    return model

# Loading the datasets from the files

In [None]:
# Images are resized to 224x224, converted to tensors and normalised with the
# ImageNet channel statistics that torchvision's pretrained ResNet weights
# were trained with; feeding raw [0, 1] pixels shifts the input distribution
# the pretrained backbone expects.
FGData = datasets.ImageFolder(
    root='./FGData/',
    transform=transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)

In [None]:
# Splitting the data into train and test
TRAIN_PCT = 0.8
# Fixed seed so the same images end up in the held-out set on every run; the
# notebook later reloads saved weights and keeps training, and an unseeded
# split after a kernel restart would move former training images into the
# evaluation set.
SPLIT_SEED = 42
train_size = int(TRAIN_PCT * len(FGData))
test_size = len(FGData) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    FGData,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# One batch size for both loaders. Only the training loader is shuffled:
# the evaluation metrics are sums over the whole set, so shuffling the
# held-out data adds nothing but non-determinism in batch order.
BATCH_SIZE = 128
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Creating the model on the device selected earlier in the notebook rather
# than hard-coding CUDA.
model = BCNN().to(device)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 327MB/s]


In [None]:
# Loss and optimizer: cross-entropy over the two output logits, Adam over
# every parameter of the model with a learning rate of 1e-3.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)


In [None]:
# First training run: the held-out split doubles as the validation set; the
# best weights go to best_bs128.pth and the per-epoch state to
# checkpoint_bs128.pth.
model = train(
    model,
    criterion,
    optimizer,
    train_loader=train_loader,
    val_loader=test_loader,
    save_location='best_bs128.pth',
    checkpoint_location='checkpoint_bs128.pth',
    n_epochs=10,
)

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 0	15.94% complete.
Epoch: 0	30.43% complete.
Epoch: 0	44.93% complete.
Epoch: 0	59.42% complete.
Epoch: 0	73.91% complete.
Epoch: 0	88.41% complete.

Epoch: 0 	Training Loss: 0.6362 	Validation Loss: 0.5955
		Training Accuracy: 63.00%	 Validation Accuracy: 66.27%
Epoch: 1	15.94% complete.
Epoch: 1	30.43% complete.
Epoch: 1	44.93% complete.
Epoch: 1	59.42% complete.
Epoch: 1	73.91% complete.
Epoch: 1	88.41% complete.

Epoch: 1 	Training Loss: 0.5867 	Validation Loss: 0.5847
		Training Accuracy: 67.55%	 Validation Accuracy: 67.55%
Epoch: 2	15.94% complete.
Epoch: 2	30.43% complete.
Epoch: 2	44.93% complete.
Epoch: 2	59.42% complete.
Epoch: 2	73.91% complete.
Epoch: 2	88.41% complete.

Epoch: 2 	Training Loss: 0.5639 	Validation Loss: 0.5968
		Training Accuracy: 69.02%	 Validation Accuracy: 66.68%
Epoch: 3	15.94% complete.
Epoch: 3	30.43% complete.
Epoch: 3	44.93% complete.
Epoch: 3	59.42% complete.
Epoch: 3	73.91% complete.
Epoch: 3	88.41% complete.

Epoch: 3 	Training Loss: 0.544

In [None]:
# Load the best weights written by the first training run; map_location puts
# the tensors on the notebook's device even if they were saved from another one.
chkpt = torch.load('./best_bs128.pth', map_location=device)

In [None]:
# Inspect which entries the loaded checkpoint dict contains.
chkpt.keys()

dict_keys(['state_dict'])

In [None]:
# Put the saved best weights back into the model before training continues.
model.load_state_dict(chkpt['state_dict'])

<All keys matched successfully>

In [None]:
# Second training run, continuing from the weights loaded above with the same
# optimizer; outputs go to the *_10.pth files.
model = train(
    model,
    criterion,
    optimizer,
    train_loader=train_loader,
    val_loader=test_loader,
    save_location='best_bs128_10.pth',
    checkpoint_location='checkpoint_bs128_10.pth',
    n_epochs=10,
)

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 0	15.94% complete.
Epoch: 0	30.43% complete.
Epoch: 0	44.93% complete.
Epoch: 0	59.42% complete.
Epoch: 0	73.91% complete.
Epoch: 0	88.41% complete.

Epoch: 0 	Training Loss: 0.5028 	Validation Loss: 0.5343
		Training Accuracy: 74.14%	 Validation Accuracy: 71.77%
Epoch: 1	15.94% complete.
Epoch: 1	30.43% complete.
Epoch: 1	44.93% complete.
Epoch: 1	59.42% complete.
Epoch: 1	73.91% complete.
Epoch: 1	88.41% complete.

Epoch: 1 	Training Loss: 0.5281 	Validation Loss: 0.5521
		Training Accuracy: 71.92%	 Validation Accuracy: 70.95%
Epoch: 2	15.94% complete.
Epoch: 2	30.43% complete.
Epoch: 2	44.93% complete.
Epoch: 2	59.42% complete.
Epoch: 2	73.91% complete.
Epoch: 2	88.41% complete.

Epoch: 2 	Training Loss: 0.4944 	Validation Loss: 0.5776
		Training Accuracy: 74.14%	 Validation Accuracy: 68.64%
Epoch: 3	15.94% complete.
Epoch: 3	30.43% complete.
Epoch: 3	44.93% complete.
Epoch: 3	59.42% complete.
Epoch: 3	73.91% complete.
Epoch: 3	88.41% complete.

Epoch: 3 	Training Loss: 0.490

In [None]:
from google.colab import files

In [None]:
# Download the best weights saved by the second training run through Colab's
# file helper.
files.download('./best_bs128_10.pth')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>