In [None]:
from pathlib import Path
import os
from functools import reduce
from nltk.tokenize import word_tokenize
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatch
from sklearn.model_selection import train_test_split
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from corpus import ColorsCorpusReader

In [None]:
def sentence2index(sentence):
    """Convert a sentence into a list of vocabulary ids.

    The sentence is split with ``word_tokenize`` and every token is looked up
    in the notebook-level ``vocab_dict``. A token that is missing from the
    vocabulary maps to the id of ``"<unk>"`` when the vocabulary defines it.

    Args:
        sentence: Text to tokenize.

    Returns:
        A list with one vocabulary id per token, in token order.

    Raises:
        KeyError: If a token is not in ``vocab_dict`` and the vocabulary has
            no ``"<unk>"`` entry to fall back on.
    """
    unk_id = vocab_dict.get("<unk>")
    ids = []
    for token in word_tokenize(sentence):
        token_id = vocab_dict.get(token, unk_id)
        if token_id is None:
            # No fallback entry exists: fail loudly on the offending token.
            raise KeyError(token)
        ids.append(token_id)
    return ids

In [None]:

def check_row_data(cols01, cols02, contexts, id=0):
    """Show one data row as two color swatches under its context text.

    The left swatch is the correct color and is drawn with a black border;
    the right swatch is the wrong color and uses its own color as border.

    Args:
        cols01: Correct colors, one per row.
        cols02: Wrong colors, one per row.
        contexts: Texts, one per row; the selected one becomes the figure title.
        id: Index of the row to display.
    """
    col01, col02, context = list(zip(cols01, cols02, contexts))[id]
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(4, 2))
    fig.suptitle(context)
    # (axis, fill color, border color) for the correct and the wrong swatch.
    swatches = ((axes[0], col01, "black"), (axes[1], col02, col02))
    for ax, fill, edge in swatches:
        # facecolor/edgecolor rather than color=...: `color` sets both the
        # face and the edge, which would discard the black border.
        patch = mpatch.Rectangle((0, 0), 1, 1, facecolor=fill, edgecolor=edge, lw=8)
        ax.add_patch(patch)
        ax.axis('off')
        ax.set_title(str(fill))
    plt.show()

## Main test code

### Prepare raw data

In [None]:
# prepare raw data
# The data directory is looked up two levels above the current working directory.
working_dir = Path(os.path.abspath(''))
root = working_dir.parent.parent.absolute()
data_path = os.path.join(root,"data")
print(data_path)
colors_csv_path = os.path.join(data_path,"colors.csv")
corpus = ColorsCorpusReader(colors_csv_path, word_count=None, normalize_colors=True)
examples = [example for example in corpus.read()]
print("Number of datapoints: {}".format(len(examples)))
# balance positive and negative samples: one correct and one wrong color per example
# NOTE(review): kept as three separate passes in this order; it is not visible here
# whether the example accessors are deterministic, so verify before merging them.
correct_color_data = list(map(lambda e: e.get_l0_data()[0], examples))
wrong_color_data = list(map(lambda e: e.get_l0_negative_data()[0], examples))
context_data = list(map(lambda e: e.get_l0_data()[1][0], examples))
# Visual sanity check of the first row.
check_row_data(correct_color_data, wrong_color_data, context_data)

### Generate vocab_dict

In [None]:
import pickle
# Reserved tokens. "<pad>" stays first so its id is 0, the value used for
# padding the context id lists and as the embedding's padding index.
special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"]
if not os.path.exists("vocab.pkl"):
    # generate vocab dict
    print("Generating vocab dict ...")
    corpus_tokens = {token for c in context_data for token in word_tokenize(c)}
    # Sorted so that token ids are identical on every run; plain set order
    # is not stable across interpreter sessions.
    vocab_list = special_tokens + sorted(corpus_tokens.difference(special_tokens))
    vocab_dict = {token: idx for idx, token in enumerate(vocab_list)}
    with open('vocab.pkl', 'wb') as f:
        pickle.dump(vocab_dict, f)
else:
    print("Loading vocab dict ...")
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only load a vocab.pkl that was produced by this notebook.
    with open('vocab.pkl', 'rb') as f:
        vocab_dict = pickle.load(f)
print("Length of the Vocab list is ",len(vocab_dict.keys()))
# Spot-check the reserved tokens and a few common color words.
print("<pad> id = ",vocab_dict["<pad>"])
print("<sos> id = ",vocab_dict["<sos>"])
print("<eos> id = ",vocab_dict["<eos>"])
print("<unk> id = ",vocab_dict["<unk>"])
print("blue id = ",vocab_dict["blue"])
print("red id = ",vocab_dict["red"])
print("green id = ",vocab_dict["green"])

### Process the data to DataLoader

In [None]:
# Batching
# Color lists -> float tensors (one row per example).
correct_color_array = np.array(correct_color_data)
wrong_color_array = np.array(wrong_color_data)
correct_color_data_tensor = torch.tensor(correct_color_array,dtype=torch.float)
wrong_color_data_tensor = torch.tensor(wrong_color_array,dtype=torch.float)
# Context sentences -> lists of vocabulary ids.
context_id_data = [sentence2index(sentence) for sentence in context_data]
# Right-pad every id list with 0 up to the length of the longest one so the
# rows stack into a single rectangular array.
max_context_len = max(len(ids) for ids in context_id_data)
padded_context_data = np.array(
    [ids + [0] * (max_context_len - len(ids)) for ids in context_id_data]
)
print("Color01 shape = ",correct_color_data_tensor.shape)
print("Color02 shape = ",wrong_color_data_tensor.shape)
print("Padded context id lists shape = ",padded_context_data.shape)

In [None]:
# prepare training data
# Labels: 1.0 for each correct-color row followed by 0.0 for each wrong-color row.
positive_labels = torch.ones(len(correct_color_data),dtype=torch.float)
negative_labels = torch.zeros(len(wrong_color_data),dtype=torch.float)
labels = torch.cat((positive_labels,negative_labels))
# Colors in the same order as the labels: correct rows first, wrong rows second.
color_data = torch.vstack((correct_color_data_tensor,wrong_color_data_tensor))
# The padded contexts are repeated once for each half.
# NOTE(review): this rebinds `context_data` from the list of sentences to an
# id array; cells above that read the sentences need a fresh load afterwards.
context_data = np.vstack((padded_context_data,padded_context_data))
print("Shape of labels = ",labels.shape,"color data = ",color_data.shape,"context data = ",context_data.shape)
# One (color, context ids) pair per row.
data = [
    (color_row,torch.tensor(context_row,dtype=torch.long))
    for color_row,context_row in zip(color_data,context_data)
]
print("total data length = ",len(data))

In [None]:
# Split data and construct dataloader
# The split is seeded so it is identical in every kernel session: the accuracy
# section evaluates a checkpoint loaded from disk on `test_batch`, and with an
# unseeded split that held-out set could contain rows the checkpoint was
# trained on in an earlier session.
SPLIT_SEED = 0
BATCH_SIZE = 128
train_x, test_x, train_y, test_y = train_test_split(data, labels, train_size=0.7, random_state=SPLIT_SEED)
# Each dataset item is ((color, context ids), label).
train_dataset = list(zip(train_x,train_y))
test_dataset = list(zip(test_x,test_y))
# Only the training batches are reshuffled every epoch.
train_batch = DataLoader(dataset=train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=0)
test_batch = DataLoader(dataset=test_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=0)

### Model setting

In [None]:
# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
class SimpleBaseLine_L0(nn.Module):
    """Scores how well a color matches a tokenized context.

    The context ids are embedded and summed over the token axis; the sum is
    concatenated with the color vector and passed through three linear
    layers (ReLU after the first two) and a final sigmoid.

    Args:
        vocab_size: Number of rows in the embedding table.
        emb_dim: Size of each token embedding.
        hidden_dim: Width of the first hidden layer; the second hidden layer
            has ``hidden_dim // 2`` units.
        output_dim: Number of outputs per example.
    """

    def __init__(self,vocab_size, emb_dim=768, hidden_dim=100, output_dim=1) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embedding_dim = emb_dim
        half_hidden = hidden_dim // 2
        # Id 0 is the padding id of the embedding table.
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        # Input is the 3-component color vector joined with the summed embedding.
        self.linear01 = nn.Linear(3 + emb_dim, hidden_dim)
        self.linear02 = nn.Linear(hidden_dim, half_hidden)
        self.linear03 = nn.Linear(half_hidden, output_dim)

    def forward(self,color_rgbs, contexts):
        """Return the sigmoid score for each (color, context) pair.

        Args:
            color_rgbs: Color tensor with 3 values per example.
            contexts: Integer tensor of token ids, one row per example.

        Returns:
            Tensor of values in (0, 1) with ``output_dim`` columns.
        """
        token_vectors = self.embedding(contexts)
        # Collapse the token axis into one vector per example.
        sentence_vectors = token_vectors.sum(dim=1).reshape(-1, self.embedding_dim)
        features = torch.hstack((color_rgbs, sentence_vectors))
        features = torch.relu(self.linear01(features))
        features = torch.relu(self.linear02(features))
        logits = self.linear03(features)
        return torch.sigmoid(logits)

In [None]:
# One embedding row per vocabulary entry; the other sizes use the class defaults.
vocab_size = len(vocab_dict)
model = SimpleBaseLine_L0(vocab_size)
# Last expression of the cell: moves the model to the device and displays it.
model.to(device)

In [None]:
# Number of passes over the training batches.
epoch = 30
# Adam over all model parameters, with the optimizer's default settings.
optimizer = optim.Adam(params=model.parameters())
# Mean squared error between the model output and the 0/1 label.
criterion = nn.MSELoss()

### Training

In [None]:
# Per-epoch mean batch losses, kept for later inspection.
train_loss_list = []
test_loss_list = []
# Start at infinity so the first epoch always writes a checkpoint.
best_loss = float("inf")
checkpoint_path = "model_params/baseline_fixed-vocab_l0.pth"
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
for i in range(epoch):
    print("##############################################")
    print("Epoch:{}/{}".format(i+1,epoch))
    train_loss = 0
    test_loss = 0

    model.train()
    # The loop variable is `inputs`, not `data`, so it does not overwrite the
    # notebook-level `data` list that the split cell reads.
    for inputs,label in train_batch:
        colors = inputs[0].to(device)
        contexts = inputs[1].to(device)
        label = label.to(device)
        optimizer.zero_grad()
        y_pred = model(colors,contexts)
        # The model output has one column; reshape the labels to match.
        loss = criterion(y_pred,label.view(-1,1))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    batch_train_loss = train_loss/len(train_batch)

    model.eval()
    # No gradients are needed for evaluation; skipping them saves memory and time.
    with torch.no_grad():
        for inputs,label in test_batch:
            colors = inputs[0].to(device)
            contexts = inputs[1].to(device)
            label = label.to(device)
            y_pred = model(colors,contexts)
            loss = criterion(y_pred,label.view(-1,1))
            test_loss += loss.item()
    batch_test_loss = test_loss/len(test_batch)

    print("Train Loss:{:.2E}, Test Loss:{:.2E}".format(batch_train_loss,batch_test_loss))
    train_loss_list.append(batch_train_loss)
    test_loss_list.append(batch_test_loss)
    if batch_test_loss < best_loss:
        # Keep only the parameters with the lowest test loss seen so far.
        torch.save(model.state_dict(),checkpoint_path)
        best_loss = batch_test_loss

## Test accuracy

In [None]:
# Rebuild the network and restore the parameters saved during training.
model = SimpleBaseLine_L0(len(vocab_dict))
model = model.to(device)
saved_state = torch.load("model_params/baseline_fixed-vocab_l0.pth",map_location=device)
# Last expression of the cell: its result reports how the keys matched.
model.load_state_dict(saved_state)

In [None]:
# Classification accuracy on the held-out batches, thresholding the score at 0.5.
test_correct = 0
total_data = 0
model.eval()
# No gradients are needed for evaluation; skipping them saves memory and time.
with torch.no_grad():
    # The loop variable is `inputs`, not `data`, so it does not overwrite the
    # notebook-level `data` list that the split cell reads.
    for inputs,label in test_batch:
        colors = inputs[0].to(device)
        contexts = inputs[1].to(device)
        label = label.to(device)
        total_data += len(label)
        y_pred = model(colors,contexts)
        # Scores above 0.5 count as class 1; flatten to line up with the labels.
        y_class = torch.where(y_pred>0.5,1,0).view(-1)
        test_correct += torch.sum(y_class==label).item()
print("Total number of data for this evaluatio is ",total_data)
print("Classification accuracy is ",test_correct/total_data)