In [None]:
from pathlib import Path
import os
from nltk.tokenize import word_tokenize
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from corpus import ColorsCorpusReader

## Prepare the data

In [None]:
def sentence2index(sentence):
    """Convert a raw utterance into a list of vocabulary ids.

    The sentence is tokenized with NLTK's ``word_tokenize`` and every token
    is looked up in the module-level ``vocab_dict``.

    Args:
        sentence: the utterance text to encode.

    Returns:
        A list of integer ids, one per token.  Tokens that are not in
        ``vocab_dict`` are mapped to the ``<unk>`` id instead of raising
        ``KeyError`` (this matters when ``vocab_dict`` was loaded from a
        cached file built on different data).
    """
    unk_id = vocab_dict["<unk>"]
    return [vocab_dict.get(w, unk_id) for w in word_tokenize(sentence)]

In [None]:
# The data directory is expected two levels above the current working directory.
root = Path(os.path.abspath('')).parent.parent.absolute()
data_path = os.path.join(root,"data")
print(data_path)
corpus = ColorsCorpusReader(os.path.join(data_path,"colors.csv"), word_count=None, normalize_colors=True)
examples = list(corpus.read())
print("Number of datapoints: {}".format(len(examples)))
# Split every example into its colour context (element 0) and its utterance (element 1).
# get_context_data() is called once per example so both lists are guaranteed to come
# from the same call (the previous version invoked it twice per example).
context_pairs = [e.get_context_data() for e in examples]
colors_data = [pair[0] for pair in context_pairs]
utterance_data = [pair[1] for pair in context_pairs]

### Generate vocab_dict

In [None]:
from functools import reduce
import pickle
# Build the token -> id mapping once and cache it in vocab.pkl; later runs reuse the cache.
special_tokens = ["<pad>","<sos>","<eos>","<unk>"]     # <pad> is added for batching
if not os.path.exists("vocab.pkl"):
    print("Generating vocab dict ...")
    # Collect unique tokens with a set: linear in the corpus size, whereas
    # reduce(list + list) re-copies the growing list on every step (quadratic).
    token_set = set()
    for utterance in utterance_data:
        token_set.update(word_tokenize(utterance))
    # Keep the special tokens at the front only, even if one shows up as a corpus token.
    token_set.difference_update(special_tokens)
    # Sort so that the ids are reproducible; plain set order depends on string hashing.
    # (original note: with nltk.tokenizer, 3953 vocabs)
    vocab_list = special_tokens + sorted(token_set)
    vocab_dict = {token: idx for idx, token in enumerate(vocab_list)}
    with open('vocab.pkl', 'wb') as f:
        pickle.dump(vocab_dict, f)
else:
    print("Loading vocab dict ...")
    # NOTE(security): pickle.load can execute arbitrary code; only load a vocab.pkl
    # that was produced by this notebook.
    with open('vocab.pkl', 'rb') as f:
        vocab_dict = pickle.load(f)
print("Length of the Vocab list is ",len(vocab_dict.keys()))
print("<pad> id = ",vocab_dict["<pad>"])
print("<sos> id = ",vocab_dict["<sos>"])
print("<eos> id = ",vocab_dict["<eos>"])
print("<unk> id = ",vocab_dict["<unk>"])
print("blue id = ",vocab_dict["blue"])
print("red id = ",vocab_dict["red"])
print("green id = ",vocab_dict["green"])

### Prepare data loader

In [None]:
# Batching
colors_data_tensor = torch.tensor(np.array(colors_data),dtype=torch.float)
context_id_data = list(map(sentence2index,utterance_data))
max_context_len = max([len(c) for c in context_id_data])
# Take the special ids from vocab_dict instead of hard-coding 1/2/0, so a cached
# vocabulary with a different layout cannot silently mislabel the sequences.
pad_id, sos_id, eos_id = vocab_dict["<pad>"], vocab_dict["<sos>"], vocab_dict["<eos>"]
# Simple_L0 builds its embedding with padding_idx=0, so <pad> must be id 0.
assert pad_id == 0, "<pad> must map to id 0 (the model uses padding_idx=0)"
# <sos>+context+<eos>+<pad>* ; built directly as int64 so that no per-row
# torch.tensor(tensor) re-wrap (which warns and copies) is needed afterwards.
padded_context_data = torch.tensor(
    np.array([[sos_id]+c+[eos_id]+[pad_id]*(max_context_len-len(c)) for c in context_id_data]),
    dtype=torch.long,
)
print("Colors shape = ",colors_data_tensor.shape)
print("Padded context id lists shape = ",padded_context_data.shape)

# One (colors, padded context ids) pair per example.
data = list(zip(colors_data_tensor,padded_context_data))
# One-hot targets: every example is labelled with index 2.
label = torch.zeros(len(data),3)
label[:,2] = 1.0
print("total data length = ",len(data))
print("total label shape = ",label.shape)

# The last `test_split` examples are held out for testing.
test_split = 7000
assert len(data) > test_split, "not enough examples to hold out test_split items"
train_data, test_data = data[:-test_split], data[-test_split:]
train_label, test_label = label[:-test_split], label[-test_split:]
print("Train, Test data length = ",len(train_data),",",len(test_data))
print("Train, Test label length = ",len(train_label),",",len(test_label))

# Each dataset item is ((colors, context_ids), label).
train_dataset = list(zip(train_data,train_label))
test_dataset = list(zip(test_data,test_label))
train_batch = DataLoader(dataset=train_dataset,batch_size=128,shuffle=True,num_workers=0)
test_batch = DataLoader(dataset=test_dataset,batch_size=128,shuffle=False,num_workers=0)

## Model

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device: ",device)

In [None]:
class Simple_L0(nn.Module):
    """Scores each candidate colour against an utterance and returns a distribution.

    The utterance is represented by the sum of its token embeddings.  Each
    candidate colour (3 values) is concatenated with that representation and
    passed through a shared three-layer MLP; the per-candidate scores are then
    normalised with a softmax.

    Args:
        vocab_size: number of rows in the embedding table.
        emb_dim: embedding size.
        hidden_dim: width of the first hidden layer (the second is half of it).
        output_dim: number of scores produced per candidate colour.
    """

    def __init__(self,vocab_size, emb_dim=768, hidden_dim=100, output_dim=1) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = emb_dim
        # Id 0 is the padding id: its embedding stays zero and adds nothing to the sum.
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.linear01 = nn.Linear(3+emb_dim,hidden_dim)
        self.linear02 = nn.Linear(hidden_dim,hidden_dim//2)
        self.linear03 = nn.Linear(hidden_dim//2, output_dim)

    def _score(self, color, hiddens):
        """Score one candidate colour (batch, 3) against the utterance features."""
        features = torch.hstack((color,hiddens))
        return self.linear03(F.relu(self.linear02(F.relu(self.linear01(features)))))

    def forward(self,color_rgbs, contexts):
        """Return softmax probabilities over the candidate colours.

        Args:
            color_rgbs: candidate colours, indexed as ``color_rgbs[:, k]`` for the
                k-th candidate — assumed shape (batch, n_candidates, 3).
            contexts: padded token ids — assumed shape (batch, seq_len).

        Returns:
            Probabilities (already softmax-normalised, NOT logits) with one
            column per candidate when ``output_dim`` is 1.
        """
        embs = self.embedding(contexts)
        hiddens = torch.sum(embs,dim=1).reshape(-1,self.embedding_dim)
        # The same MLP scores every candidate; looping over the candidate axis
        # replaces three copy-pasted expressions and works for any candidate count.
        scores = [self._score(color_rgbs[:,k],hiddens) for k in range(color_rgbs.shape[1])]
        y_hat = F.softmax(torch.cat(scores,dim=1),dim=-1)
        return y_hat

## Training

### Settings

In [None]:
model = Simple_L0(len(vocab_dict)).to(device)


class _ProbabilityCrossEntropy(nn.Module):
    """Cross-entropy for a model that already outputs probabilities.

    Simple_L0.forward ends with a softmax, while nn.CrossEntropyLoss expects
    unnormalised logits and applies log-softmax itself.  Combining the two
    normalises twice, which flattens the gradients and keeps the loss from
    approaching zero.  This loss takes the log of the probabilities directly.

    Args:
        eps: lower clamp applied before the log to avoid log(0).
    """

    def __init__(self, eps=1e-12):
        super().__init__()
        self.eps = eps

    def forward(self, probs, target):
        """Return the batch-mean of -sum(target * log(probs)) over the class axis.

        ``target`` holds one-hot (or soft) class probabilities with the same
        shape as ``probs``.
        """
        log_probs = torch.log(probs.clamp_min(self.eps))
        return -(target * log_probs).sum(dim=1).mean()


criterion = _ProbabilityCrossEntropy()
optimizer = optim.Adam(model.parameters())
# Total number of training epochs; also embedded in the checkpoint file names.
epoch = 100

### Training loop

In [None]:
def _run_epoch(model, batches, criterion, device, optimizer=None):
    """Run one pass over ``batches`` and return (mean batch loss, mean batch accuracy).

    When ``optimizer`` is given the model is put in training mode and updated
    after every batch; otherwise the model is put in eval mode and gradients
    are disabled.  Both returned values are averages over the batches.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    total_acc = 0.0
    with torch.set_grad_enabled(training):
        # `labels` (not `label`) so the module-level `label` tensor is not overwritten.
        for (cols,context),labels in batches:
            colors,context,labels = cols.to(device), context.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            y_pred = model(colors,context)
            loss = criterion(y_pred,labels)
            if training:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            # Fraction of the batch whose arg-max prediction matches the one-hot label.
            total_acc += (y_pred.argmax(1)==labels.argmax(1)).float().mean().item()
    return total_loss/len(batches), total_acc/len(batches)


train_loss_list = []
test_loss_list = []
train_acc_list = []
test_acc_list = []
best_loss = float("inf")
best_acc = 0
# torch.save does not create missing directories.
os.makedirs("model_params", exist_ok=True)
for i in range(epoch):
    print("##############################################")
    print("Epoch:{}/{}".format(i+1,epoch))
    batch_train_loss, batch_train_acc = _run_epoch(model,train_batch,criterion,device,optimizer)
    batch_test_loss, batch_test_acc = _run_epoch(model,test_batch,criterion,device)

    print("Train Loss:{:.2E}, Test Loss:{:.2E}".format(batch_train_loss,batch_test_loss))
    print("Train Acc:{:.2E}, Test Acc:{:.2E}".format(batch_train_acc,batch_test_acc))
    train_loss_list.append(batch_train_loss)
    test_loss_list.append(batch_test_loss)
    train_acc_list.append(batch_train_acc)
    test_acc_list.append(batch_test_acc)
    # The file names carry the total epoch count (`epoch`), not the current epoch,
    # so each "best" checkpoint is overwritten in place as training improves.
    if batch_test_loss < best_loss:
        print("Best Loss saved: {:.2E}".format(batch_test_loss))
        torch.save(model.state_dict(),"model_params/Simple-l0_epoch="+str(epoch)+"_best-loss.pth")
        best_loss = batch_test_loss
    if batch_test_acc > best_acc:
        print("Best Acc saved: {:.2E}".format(batch_test_acc))
        torch.save(model.state_dict(),"model_params/Simple-l0_epoch="+str(epoch)+"_best-acc.pth")
        best_acc = batch_test_acc

## Evaluation

In [None]:
# Rebuild the network and restore the weights saved under the "best-acc" file name.
epoch = 100  # epoch count that is embedded in the checkpoint file name
pth = "".join(["model_params/Simple-l0_epoch=", str(epoch), "_best-acc.pth"])
model = Simple_L0(len(vocab_dict))
model = model.to(device)
model.load_state_dict(torch.load(pth,map_location=device))

In [None]:
def get_batch_acc(preds,labels):
    """Return the mean per-example score of a batch, with fixed credit for ties.

    For each (probs, label) pair:
      * every entry tied at the maximum (three-way tie)  -> 0.3
      * exactly two entries tied at the maximum          -> 0.5
      * a unique maximum                                 -> 1 if its index equals
        the arg-max of ``label``, else 0

    The previous version tested the 0/2 tie against ``probs[1]`` instead of
    ``probs[0]``: a tie between entries 0 and 2 at the top received no tie
    credit, while an unambiguous index-1 prediction with equal losers was
    scored 0.5.  Counting the entries equal to the maximum removes that
    index mix-up.

    NOTE(review): tie credit is given whether or not the labelled entry is
    among the tied ones — kept as in the original scoring scheme.

    Args:
        preds: iterable of 1-D probability tensors (three entries each in this file).
        labels: iterable of one-hot label tensors aligned with ``preds``.

    Returns:
        ``np.mean`` of the per-example scores.
    """
    score = []
    for probs,label in zip(preds,labels):
        n_tied = int((probs == probs.max()).sum())   # how many entries share the maximum
        if n_tied >= 3:        # all same
            score.append(0.3)
        elif n_tied == 2:      # two-way tie for the top score
            score.append(0.5)
        else:
            predict = torch.argmax(probs)
            correct = torch.argmax(label)
            score.append(int(predict==correct))
    return np.mean(score)

def accuracy(model,test_batch,device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    """Evaluate ``model`` on ``test_batch`` and return the mean per-batch score.

    The model is switched to eval mode and run without gradients.  Each batch
    of ``((colors, context), label)`` is moved to ``device``, scored with
    ``get_batch_acc``, and the batch scores are averaged over ``len(test_batch)``.
    """
    model.eval()
    batch_scores = []
    with torch.no_grad():
        for (cols,context),label in test_batch:
            colors = cols.to(device)
            context = context.to(device)
            label = label.to(device)
            batch_scores.append(get_batch_acc(model(colors,context),label))
    return sum(batch_scores)/len(test_batch)

In [None]:
# Score the reloaded model on the held-out batches and report the result.
acc = accuracy(model=model, test_batch=test_batch)
print("Accuracy: " + str(acc))