In [None]:
import numpy as np
import pandas as pd
import os
import time
import torch
import torch.nn as nn
import sys
import math
import random

import torchvision
from torchvision import models
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import ConcatDataset
from torch.utils.data import random_split
from torchsummary import summary

from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image

import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, f1_score

In [None]:
AAF_TRAIN_PATH = 'drive/My Drive/AWL Internship/MixMatch/AAF_Gender_Classification/AAF_train_MixMatch.xlsx'
AAF_TEST_PATH = 'drive/My Drive/AWL Internship/MixMatch/AAF_Gender_Classification/AAF_test_MixMatch.xlsx'
!unzip -q 'drive/My Drive/AWL Internship/AAF Dataset/extracted_original-20'
AAF_IMAGE_PATH = 'extracted_original-20'

In [None]:
CUDA = 0  # index of the GPU to use when CUDA is available
RANDOM_SEED = 1
LEARNING_RATE = 0.0001
NUM_EPOCHS = 5
BATCH_SIZE = 32
NUM_LABELLED = 1000  # No of labelled examples to be used in MixMatch

# Fall back to the CPU when no GPU is visible instead of failing further down.
DEVICE = torch.device("cuda:%d" % CUDA if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())
if torch.cuda.is_available():
    # Report the device that is actually selected (index CUDA, not a fixed 0).
    print(torch.cuda.get_device_name(CUDA))

In [None]:
class AAF_Dataset(Dataset):
    '''Custom Dataset for loading AAF Dataset images.

    Args:
        csv_path: path of the spreadsheet listing the samples. It is read with
            ``pandas.read_excel`` and must contain an ``Image`` column (file
            names) and a ``Gender`` column (labels).
        img_dir: directory the file names in ``Image`` are relative to.
        transform: callable applied to the loaded PIL image; ``None`` leaves
            the image untouched.

    Each item is a ``(img, y_true)`` pair where ``y_true`` is the ``Gender``
    value as a scalar ``float32`` tensor.
    '''

    def __init__(self, csv_path, img_dir, transform=None):
        df = pd.read_excel(csv_path)
        self.img_dir = img_dir
        self.transform = transform
        self.csv_path = csv_path
        self.gender = df['Gender'].values
        self.filename = df['Image'].values

    # NOTE: any further preprocessing required in the data can be added here
    # (e.g. as a `preprocess` method).

    def __getitem__(self, index):
        img_path = os.path.join(self.img_dir, self.filename[index])
        # Close the file handle deterministically, and force three channels so
        # grayscale / RGBA files still produce tensors of a consistent shape.
        with Image.open(img_path) as handle:
            img = handle.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        y_true = torch.tensor(self.gender[index], dtype=torch.float32)

        return img, y_true

    def __len__(self):
        return len(self.gender)

In [None]:
# Resize every image to 96x96 and convert it to a tensor.
custom_transform = transforms.Compose([transforms.Resize((96, 96)),
                                       transforms.ToTensor()])

# Small loader used only to inspect a few samples.
sample_batch_size = 4
sample_dataset = AAF_Dataset(csv_path=AAF_TRAIN_PATH, img_dir=AAF_IMAGE_PATH, transform=custom_transform)
sample_loader = DataLoader(dataset=sample_dataset, batch_size=sample_batch_size, shuffle=True)

# Use the built-in next(): DataLoader iterators do not offer a `.next()`
# method in current PyTorch releases.
dataiter = iter(sample_loader)
images, labels = next(dataiter)

print(images.shape)
print(labels.shape)

print(images[0].shape)
print(labels[0].item())

def imshow(img, title):
    '''Show a channels-first image (such as a make_grid result) under `title`.'''
    # Figure width grows with the module-level sample_batch_size.
    fig_size = (sample_batch_size * 4, 4)
    # matplotlib wants the channel axis last: (C, H, W) -> (H, W, C).
    channels_last = np.transpose(img, (1, 2, 0))

    plt.figure(figsize=fig_size)
    plt.axis('off')
    plt.imshow(channels_last)
    plt.title(title)
    plt.show()

def show_batch_images(dataloader):
    '''Fetch one batch from `dataloader`, plot it as an image grid and print its labels.

    Returns:
        The `(images, labels)` pair of the batch that was displayed.
    '''
    batch_images, batch_labels = next(iter(dataloader))

    grid = torchvision.utils.make_grid(batch_images)
    imshow(grid, title='Images')
    print(batch_labels)

    return batch_images, batch_labels


images, labels = show_batch_images(sample_loader)

In [None]:
AAF_train_full = AAF_Dataset(csv_path=AAF_TRAIN_PATH, img_dir=AAF_IMAGE_PATH, transform=custom_transform)

# Draw the labelled subset with an explicitly seeded generator: this cell runs
# before any global seed is set, so an unseeded split would select a different
# set of NUM_LABELLED examples on every run.
split_generator = torch.Generator().manual_seed(RANDOM_SEED)
AAF_train, _ = random_split(
    AAF_train_full,
    [NUM_LABELLED, len(AAF_train_full) - NUM_LABELLED],
    generator=split_generator,
)

trainloader = DataLoader(AAF_train, batch_size=BATCH_SIZE, shuffle=True)

AAF_test = AAF_Dataset(csv_path=AAF_TEST_PATH, img_dir=AAF_IMAGE_PATH, transform=custom_transform)

# Accuracy / F1 do not depend on sample order, so the test set is not shuffled.
testloader = DataLoader(AAF_test, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Number of labelled training examples, then number of test examples.
print(*map(len, (AAF_train, AAF_test)))

In [None]:
# Both criteria sum the per-element losses of a batch instead of averaging them.
cross_entropy, l2_loss = (
    nn.CrossEntropyLoss(reduction='sum'),
    nn.MSELoss(reduction='sum'),
)

In [None]:
# Pretrained Wide-ResNet-50-2 used as the feature backbone.
model_1 = models.wide_resnet50_2(pretrained=True)
model_1.to(DEVICE)
# summary() prints the layer table itself, so wrapping it in print() only added
# its return value to the output. The trailing semicolon keeps the notebook
# from echoing that return value.
summary(model_1, (3, 96, 96));

In [None]:
class Gender_Classifier(nn.Module):
    '''Three-layer fully connected head mapping 1000 features to 2 class logits.

    ReLU activations separate the linear layers. Without them the three layers
    compose into a single affine map and the extra layers add no capacity.
    Parameter names (fc1, fc2, fc3) are unchanged.
    '''

    def __init__(self):
        super(Gender_Classifier, self).__init__()

        self.fc1 = nn.Linear(1000, 100)
        self.fc2 = nn.Linear(100, 10)
        self.fc3 = nn.Linear(10, 2)

    def forward(self, x):
        '''Return unnormalised class scores of shape (batch, 2) for input of shape (batch, 1000).'''
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        # No activation on the last layer: the loss / softmax consume raw logits.
        logits = self.fc3(x)
        return logits

In [None]:
# Seed before building the head so its random initial weights are reproducible;
# without this the head is initialised before any seed has been set.
torch.manual_seed(RANDOM_SEED)
model_2 = Gender_Classifier()
# Backbone followed by the classification head.
model = nn.Sequential(model_1, model_2)
model.to(DEVICE)
# summary() prints the layer table itself; print() around it only added its
# return value to the output.
summary(model, (3, 96, 96));

In [None]:
# Seed the CPU and the current CUDA generator before training starts.
torch.manual_seed(seed=RANDOM_SEED)
torch.cuda.manual_seed(seed=RANDOM_SEED)

# Adam over every parameter of the combined model, with L2 weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=0.01,
)

In [None]:
start_time = time.time()
# Actual number of batches per epoch, including a final partial batch
# (len(AAF_train)//BATCH_SIZE undercounts when the sizes do not divide evenly).
batches_per_epoch = len(trainloader)
for epoch in range(NUM_EPOCHS):

    # ---- training pass over the labelled subset ----
    model.train()
    for batch_idx, (x, y) in enumerate(trainloader):
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        logits = model(x)

        # The dataset yields float labels; CrossEntropyLoss needs integer class indices.
        loss = cross_entropy(logits, y.long())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 50 == 0:
            s = ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f '
                 % (epoch+1, NUM_EPOCHS, batch_idx,
                    batches_per_epoch, loss.item()))
            print(s)

    s = 'Time elapsed: %.2f min' % ((time.time() - start_time)/60)
    print(s)

    # ---- evaluation on the test set after every epoch ----
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for img, label in testloader:
            img = img.to(DEVICE)

            # argmax over the logits selects the same class as argmax over softmax.
            pred = torch.argmax(model(img), dim=1)
            y_true.extend(label.cpu().numpy())
            y_pred.extend(pred.cpu().numpy())
    print(accuracy_score(y_true, y_pred))
    print(f1_score(y_true, y_pred))

In [None]:
# Accuracy / F1 of the trained model on the labelled training subset.
model.eval()
y_true = []
y_pred = []
with torch.no_grad():
    for img, label in trainloader:
        img = img.to(DEVICE)

        # argmax over the logits selects the same class as argmax over softmax.
        pred = torch.argmax(model(img), dim=1)
        y_true.extend(label.cpu().numpy())
        y_pred.extend(pred.cpu().numpy())
print(accuracy_score(y_true, y_pred))
print(f1_score(y_true, y_pred))

In [None]:
# Constant baseline: predict 0 for every entry currently held in y_true.
y_naive = np.zeros(shape=len(y_true))

In [None]:
# Accuracy obtained by the all-zeros baseline on the same labels.
print(accuracy_score(y_true=y_true, y_pred=y_naive))