# Define transforms.


# In [5]:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms, datasets

# Define the transform.
def _build_frame_transform():
    """Return a fresh preprocessing pipeline for one video frame.

    The train, validation and test splits all use the same deterministic
    steps, so the pipeline is defined once here instead of being copied
    three times with duplicated constants.
    """
    return transforms.Compose([
        # Resize to a fixed 224x224 input (PIL image in, PIL image out).
        transforms.Resize((224, 224)),
        # Convert the PIL image to a torch tensor.
        transforms.ToTensor(),
        # Per-channel normalization (tensor in, tensor out); per the original
        # note, the mean and std values are the ones computed in the next step.
        transforms.Normalize(mean=[0.4280, 0.4106, 0.3589],
                             std=[0.2737, 0.2631, 0.2601]),
    ])


# Each split gets its own Compose instance, exactly as before.
train_transform = _build_frame_transform()
valid_transform = _build_frame_transform()
test_transform = _build_frame_transform()

# Prepare the dataset for model training

# In [6]:
from torch.utils.data import Dataset
from skimage import io
from PIL import Image
from matplotlib import cm
import pandas as pd

class dataset(Dataset):
    """Map-style dataset pairing a query with a frame image and its scores.

    Each row of the annotation CSV is read positionally:
    column 0 is the query, column 2 is a "/"-separated frame path, and
    columns 3 onward are numeric score annotations.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the frames.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.query_frame_train = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of annotated rows (one sample per CSV row)."""
        return len(self.query_frame_train)

    def __getitem__(self, idx):
        """Load the sample stored at row ``idx``.

        Returns:
            dict with keys ``'image'``, ``'query'`` and ``'score_annotations'``.
            Without a transform the image is whatever ``io.imread`` returns
            and the scores are a 1-D float NumPy array. With a transform the
            image is the transform's output (applied to a PIL image) and the
            scores are converted to a torch tensor.
        """
        # Local import: at module level `os` is only imported further down
        # the file, so do not depend on cell/definition order.
        import os

        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Only the last two path components (presumably "<video>/<frame>")
        # are kept and re-rooted under root_dir. Slicing with [-2:] also
        # copes with a path that has a single component.
        path_parts = self.query_frame_train.iloc[idx, 2].split("/")
        img_name = os.path.join(self.root_dir, *path_parts[-2:])
        image = io.imread(img_name)

        query = self.query_frame_train.iloc[idx, 0]

        # Flat float vector of all score columns. copy=True guarantees a
        # writable array that torch.from_numpy can wrap without warnings.
        score_annotations = self.query_frame_train.iloc[idx, 3:].to_numpy(
            dtype=float, copy=True)

        sample = {'image': image, 'query': query, 'score_annotations': score_annotations}

        if self.transform:
            # The transform receives a PIL image, so convert the loaded
            # array first.
            sample['image'] = self.transform(Image.fromarray(sample['image']))
            sample['score_annotations'] = torch.from_numpy(sample['score_annotations'])
        return sample

# Query embedding

# In [7]:
def query_embedding(box_train_query, word2index_x, max_length):
    """Encode a batch of queries as a fixed-width float tensor.

    Args:
        box_train_query: Queries to encode; passed unchanged to
            ``encode_queries_index``.
        word2index_x: Vocabulary mapping; passed unchanged to
            ``encode_queries_index``.
        max_length (int): Width of every returned row. Shorter encodings are
            zero-padded on the right, longer ones are truncated.

    Returns:
        torch.Tensor: One row per query, each of length ``max_length``.
    """
    encoded_queries = encode_queries_index(box_train_query, word2index_x)
    query_tensors = [torch.FloatTensor(indices) for indices in encoded_queries]

    # pad_sequence right-pads every row with zeros to the longest query in
    # the batch, so all rows share one length from here on.
    padded = pad_sequence(query_tensors, batch_first=True, padding_value=0)

    batch_length = padded.size(1)
    if batch_length < max_length:
        filler = padded.new_zeros(padded.size(0), max_length - batch_length)
        return torch.cat((padded, filler), dim=1)

    # Truncate to the requested width. This used to be a hard-coded 8, which
    # was only correct when max_length happened to be 8.
    return padded[:, :max_length].contiguous()

# Define model

# In [8]:
from BasicModule import BasicModule
from torch import nn
from torch.nn import functional as F
from torchvision import models
from torchvision.models import resnet34

class QVSmodel(BasicModule):
    """Query-conditioned frame classifier built on a ResNet-34 backbone.

    The image passes through the ResNet-34 convolutional stages, the query
    vector is projected to the same width, the two are fused by element-wise
    multiplication, and a linear layer maps the result to 2 outputs.
    """

    def __init__(self):
        super(QVSmodel, self).__init__()

        # Build the pretrained backbone once. It used to be constructed twice
        # (the first instance was immediately overwritten), which only cost
        # an extra model build and weight load.
        self.model = models.resnet34(pretrained=True)

        # 2-way classification head on the fused 512-d feature.
        self.fc1 = torch.nn.Linear(512, 2)

        # Projects the 8-element query vector to the 512-d image feature width.
        self.fc_text = torch.nn.Linear(8, 512)

    def forward(self, x, y):
        """Return relevance logits for an image batch and its query vectors.

        Args:
            x: Image batch fed to the ResNet-34 stem. The fixed 7x7 pooling
                below assumes 224x224 inputs (as produced by the transforms
                in this file) -- TODO confirm for other callers.
            y: Query features whose last dimension is 8 (input width of
                ``fc_text``).

        Returns:
            Tensor with 2 scores per sample (output of ``fc1``).
        """
        # Run the backbone up to (but not including) its own pooling/fc head.
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        # Average over 7x7 windows, then flatten to one vector per sample.
        x = F.avg_pool2d(x, 7)
        x = x.view(x.size(0), -1)

        y = F.relu(self.fc_text(y))

        # Fuse image and query features by element-wise multiplication; the
        # result keeps the per-sample width of 512, i.e. (batch, 512).
        fused = torch.mul(x, y)

        # Final fully connected layer producing the class scores.
        relevance_class_prediction = self.fc1(fused)

        return relevance_class_prediction

# Model training

# In [9]:
import os
import time
import torch.optim as optim

def trainNet(model, batch_size, n_epochs, learning_rate, *,
             data_loader=None, word2index=None, max_query_length=8,
             log_path='PATH'):
    """Train ``model`` on the relevance label and return the trained network.

    Args:
        model: Network to train. It may return the logits directly or a
            tuple/list whose first element is the logits.
        batch_size: Only reported in the hyperparameter banner; the actual
            batch size is whatever ``data_loader`` yields.
        n_epochs (int): Number of passes over ``data_loader``.
        learning_rate (float): Learning rate for the Adam optimizer.
        data_loader: Iterable of batches, each a dict with the keys
            ``'image'``, ``'query'`` and ``'score_annotations'``. Defaults to
            the module-level ``train_loader``.
        word2index: Vocabulary passed to ``query_embedding``. When given, the
            batch queries are embedded and the model is called as
            ``model(images, query_features)``; when omitted, the model is
            called with the images only.
        max_query_length (int): Width of the embedded query rows.
        log_path (str): File that per-batch progress lines are appended to.

    Returns:
        The trained network (the same module, moved to the training device).
    """
    # Use the GPU when one is present, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = model.to(device)
    net.train()

    print("===== HYPERPARAMETERS =====")
    print("batch_size=", batch_size)
    print("epochs=", n_epochs)
    print("learning_rate=", learning_rate)
    print("=" * 30)

    # Get data. This used to bind the Dataset *class* by mistake; fall back
    # to the module-level loader that the accuracy computation already used.
    if data_loader is None:
        data_loader = train_loader
    n_batches = len(data_loader)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)

    training_start_time = time.time()

    # Append mode: the log is written to (the old read-only open() could not
    # be), and the context manager closes it even if training raises.
    with open(log_path, 'a') as log_file:
        for epoch in range(n_epochs):
            correct_relevance = 0
            samples_seen = 0
            start_time = time.time()

            for i_batch, sample_batched in enumerate(data_loader):
                # Unpack the batch before touching any of its parts.
                inputs = sample_batched['image'].to(device)
                query = sample_batched['query']
                labels = sample_batched['score_annotations']

                # The first score column is the relevance class label.
                labels_relevance = labels[:, 0].long().to(device)

                optimizer.zero_grad()

                if word2index is None:
                    outputs = net(inputs)
                else:
                    # NOTE(review): assumes `query` is the collated batch of
                    # raw queries accepted by query_embedding -- confirm.
                    query_features = query_embedding(
                        query, word2index, max_query_length).to(device)
                    outputs = net(inputs, query_features)

                # Multi-head models return a tuple; single-head ones a tensor.
                if isinstance(outputs, (tuple, list)):
                    logits = outputs[0]
                else:
                    logits = outputs

                loss_size = criterion(logits, labels_relevance)
                loss_size.backward()
                optimizer.step()

                batch_loss = loss_size.item()

                # Accumulate accuracy over the samples actually processed.
                predictions = logits.argmax(dim=1)
                correct_relevance += (predictions == labels_relevance).sum().item()
                samples_seen += labels_relevance.size(0)

                progress = "Epoch {}, {:d}% \t train_loss_{}_batch: {:.4f} \t took: {:.4f}s".format(
                    epoch+1, int(100 * (i_batch+1) / n_batches), i_batch+1,
                    batch_loss, time.time() - start_time)
                print(progress)
                log_file.write(progress + " \n")
                log_file.flush()

                # Restart the per-batch timer.
                start_time = time.time()

            # Divide by the true sample count rather than a hard-coded
            # per-batch size; guard against an empty loader.
            acc_relevance = (correct_relevance / samples_seen) if samples_seen else 0.0
            print("Training accuracy_relevance = {:.4f} for epoch {}".format(acc_relevance, epoch +1))

    print("Training finished, took {:.4f}s".format(time.time() - training_start_time))

    return net