In [None]:
from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from corpus import ColorsCorpusReader

In [None]:
# Selects which vocabulary supplies the special-token ids further down:
# True -> GPT-2 tokenizer ids, False -> the pickled word-level vocabulary.
GPT = True

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device: ", device)

## Prepare data

In [None]:
def check_row_data(cols, contexts, id=0):
    """Draw the three colours of one example side by side, titled with its utterance.

    Args:
        cols: indexable collection; ``cols[id]`` must unpack into exactly three
            colours that matplotlib accepts as a patch colour.
        contexts: indexable collection; ``contexts[id]`` becomes the figure title.
        id: index of the example to display.

    The first two swatches are outlined in their own colour, the third one in
    black (presumably because the third colour is the target -- the labels
    built later in this notebook mark index 2; confirm against the corpus reader).
    """
    col01, col02, col03 = cols[id]
    title = contexts[id]
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(4, 3))
    fig.suptitle(title)
    # (face colour, edge colour) for each of the three panels, left to right.
    swatches = ((col01, col01), (col02, col02), (col03, "black"))
    for ax, (face, edge) in zip(axes, swatches):
        ax.add_patch(mpatch.Rectangle((0, 0), 1, 1, color=face, ec=edge, lw=8))
        ax.axis('off')
    plt.show()

In [None]:
from nltk.tokenize import word_tokenize
def sentence2index(sentence):
    """Return ``tokenizer.encode(sentence)`` using the notebook-level ``tokenizer``.

    ``tokenizer`` is a global created in the vocabulary cell, so that cell has
    to run before this function is called.
    """
    return tokenizer.encode(sentence)

In [None]:
# Locate the data directory: two levels above the current working directory, then "data".
root = Path(os.path.abspath('')).parent.parent.absolute()
data_path = os.path.join(root,"data")
print(data_path)

# Read every example from the colours CSV.
corpus = ColorsCorpusReader(
    os.path.join(data_path,"colors.csv"),
    word_count=None,
    normalize_colors=True,
)
examples = list(corpus.read())
print("Number of datapoints: {}".format(len(examples)))

# get_context_data() returns a pair: element 0 is stored as the colours and
# element 1 as the utterance. The two passes are kept separate on purpose.
colors_data = list(map(lambda example: example.get_context_data()[0], examples))
utterance_data = list(map(lambda example: example.get_context_data()[1], examples))

# Visual sanity check of a single example.
check_row_data(colors_data,utterance_data,id=10)

### Generating vocab_dict

In [None]:
from functools import reduce
import pickle
from transformers import GPT2Tokenizer

# --- GPT-2 tokenizer -------------------------------------------------------
# Created whenever GPT mode is on, independent of whether vocab.pkl exists,
# so that `tokenizer` / `gpt_vocab_dict` are always defined before they are used.
if GPT:
    print("Generating Vocab_dict from GPT tokenizer ...")
    tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
    gpt_vocab_dict = tokenizer.get_vocab()

# --- Word-level vocabulary (word -> id), cached in vocab.pkl -----------------
# `word_tokenize` comes from the nltk import in an earlier cell.
VOCAB_PATH = "vocab.pkl"
if not os.path.exists(VOCAB_PATH):
    print("Generating vocab dict ...")
    # A set comprehension avoids the quadratic list concatenation, and sorting
    # makes the word -> id mapping identical from run to run (iteration order
    # of a set of strings is not stable across interpreter sessions).
    vocab_words = sorted({word for utterance in utterance_data for word in word_tokenize(utterance)})
    # The four special tokens come first so that their ids are 0..3.
    vocab_list = ["<pad>","<sos>","<eos>","<unk>"] + vocab_words
    vocab_dict = {word: index for index, word in enumerate(vocab_list)}
    with open(VOCAB_PATH, 'wb') as f:
        pickle.dump(vocab_dict, f)
else:
    print("Loading vocab dict ...")
    # SECURITY: pickle.load can execute arbitrary code; only load a vocab.pkl
    # that was produced by this notebook.
    with open(VOCAB_PATH, 'rb') as f:
        vocab_dict = pickle.load(f)

# --- Special-token ids -------------------------------------------------------
if not GPT:
    print("Length of the Vocab list is ",len(vocab_dict))
    print("<pad> id = ",vocab_dict["<pad>"])
    print("<sos> id = ",vocab_dict["<sos>"])
    print("<eos> id = ",vocab_dict["<eos>"])
    print("<unk> id = ",vocab_dict["<unk>"])
    print("blue id = ",vocab_dict["blue"])
    print("red id = ",vocab_dict["red"])
    print("green id = ",vocab_dict["green"])
    PAD = 0
    SOS = 1
    EOS = 2
    UNK = 3
else:
    print("Length of the Vocab list is ",len(gpt_vocab_dict))
    print("PAD id = ",gpt_vocab_dict["pad"])
    print("BOS id = ",gpt_vocab_dict["<|endoftext|>"])
    print("EOS id = ",gpt_vocab_dict["<|endoftext|>"])
    print("UNK id = ",gpt_vocab_dict["<|endoftext|>"])
    print("blue id = ",gpt_vocab_dict["blue"])
    print("red id = ",gpt_vocab_dict["red"])
    print("green id = ",gpt_vocab_dict["green"])
    # The ordinary token "pad" doubles as the padding id, and "<|endoftext|>"
    # serves as start, end and unknown token. The ids are read from the
    # tokenizer rather than hard-coded (previously 15636 and 50256).
    PAD = gpt_vocab_dict["pad"]
    SOS = EOS = UNK = gpt_vocab_dict["<|endoftext|>"]
    # Special-token ids of the word-level vocabulary (positions 0..3 of vocab_list above).
    original_PAD = 0
    original_SOS = 1
    original_EOS = 2
    original_UNK = 3

# Lookup tables for the word-level vocabulary: word -> id and id -> word.
w2i = vocab_dict
i2w = {index: word for word, index in vocab_dict.items()}

### Prepare the data loader

In [None]:
# Colours of every example as one float tensor.
colors_data_tensor = torch.tensor(np.array(colors_data),dtype=torch.float)

# Utterances as token ids, wrapped as <sos> + tokens + <eos> and right-padded
# with <pad> so that every row has length max_context_len + 2.
context_id_data = list(map(sentence2index,utterance_data))
max_context_len = max(len(c) for c in context_id_data)
print("MAX length = ",max_context_len)
# The dtype is fixed to long here, once, so the ids do not depend on numpy's
# platform default integer width and no per-row re-wrapping is needed below.
padded_context_data = torch.tensor(
    np.array([[SOS]+c+[EOS]+[PAD]*(max_context_len-len(c)) for c in context_id_data]),
    dtype=torch.long,
)
print("Colors shape = ",colors_data_tensor.shape)
print("Padded context id lists shape = ",padded_context_data.shape)

# One (colours, token ids) pair per example. Pairing the rows directly avoids
# calling torch.tensor() on an existing tensor (copy-construct warning + extra copy).
data = list(zip(colors_data_tensor, padded_context_data))
# One-hot target over the three colours: index 2 is marked for every example.
label = torch.zeros(len(data),3)
label[:,2] = 1.0
print("total data length = ",len(data))
print("total label shape = ",label.shape)

# Hold out the last `test_split` examples for testing (no shuffling before the split).
test_split = 7000
assert 0 < test_split < len(data), "test_split must leave both a train and a test set"
train_data, test_data = data[:-test_split], data[-test_split:]
train_label, test_label = label[:-test_split], label[-test_split:]
print("Train, Test data length = ",len(train_data),",",len(test_data))
print("Train, Test label length = ",len(train_label),",",len(test_label))

# Each dataset item is ((colours, token ids), label); batches keep that nesting.
train_dataset = list(zip(train_data,train_label))
test_dataset = list(zip(test_data,test_label))
train_batch = DataLoader(dataset=train_dataset,batch_size=32,shuffle=True,num_workers=0)
test_batch = DataLoader(dataset=test_dataset,batch_size=32,shuffle=False,num_workers=0)

## Model

### Encoder

In [None]:
# For permutation invariance
class Colors_DeepSet(nn.Module):
    """Order-independent embedding of a pair of colours.

    Both colours go through the same linear layer + ReLU, the two embeddings
    are added, and the sum is passed through a second linear layer. Because
    the colours are combined by addition, swapping them gives the same result.

    Args:
        input_size: number of features per colour.
        output_size: size of the returned embedding.
    """

    def __init__(self, input_size=3, output_size=16):
        super().__init__()
        self.linear1 = nn.Linear(input_size, output_size)
        self.linear2 = nn.Linear(output_size, output_size)

    def forward(self, col1, col2):
        """Return ``linear2(relu(linear1(col1)) + relu(linear1(col2)))``."""
        embedded = [F.relu(self.linear1(color)) for color in (col1, col2)]
        return self.linear2(embedded[0] + embedded[1])

In [None]:
class Colors_Feature_Encoder(nn.Module):
    """Encode a colour context (one target + two distractors) into one vector.

    The target colour is concatenated with a permutation-invariant embedding
    of the two remaining colours and projected to ``hidden_size``.

    Args:
        input_size: number of features per colour.
        hidden_size: size of the returned feature vector.
    """

    def __init__(self, input_size=3, hidden_size=16):
        super().__init__()
        self.deepset_size = 16
        self.deepset = Colors_DeepSet(input_size, self.deepset_size)
        self.linear = nn.Linear(input_size+self.deepset_size, hidden_size)

    def forward(self,feats,labels):
        """Return one feature vector per example.

        Args:
            feats: colours, indexed as ``feats[example, colour, feature]``;
                expected to hold three colours per example.
            labels: target scores per colour, either one row per example
                ``(batch, 3)`` or a single row ``(3,)`` shared by the batch.
                The arg-max of each row selects that example's target.

        The target is chosen per example. Previously the arg-max was taken
        over the flattened label tensor, i.e. the first example's target was
        applied to the whole batch; for batches whose rows all share the same
        target the result is unchanged.
        """
        batch_size, n_colors = feats.shape[0], feats.shape[1]
        if labels.dim() == 1:
            labels = labels.unsqueeze(0).expand(batch_size, -1)
        target_idx = labels.argmax(dim=1)
        rows = torch.arange(batch_size, device=feats.device)
        target_col = feats[rows, target_idx]
        # Boolean mask of the non-target colours; masking keeps them in
        # ascending index order, matching the previous [other_idx1, other_idx2].
        is_distractor = torch.ones(batch_size, n_colors, dtype=torch.bool, device=feats.device)
        is_distractor[rows, target_idx] = False
        distractors = feats[is_distractor].reshape(batch_size, n_colors - 1, feats.shape[-1])
        col1, col2 = distractors[:, 0], distractors[:, 1]
        col_embs = F.relu(self.deepset(col1,col2))
        cols = torch.hstack((target_col,col_embs))
        return self.linear(cols)

### Speaker Encoder-Decoder model

In [None]:
# Define the encoder-decoder model
from transformers import EncoderDecoderModel

class S1_EncoderDecoder(torch.nn.Module):
    """Speaker model: a colour-context encoder feeding a pretrained distilgpt2 decoder.

    The encoded context is handed to the decoder as a single
    ``encoder_hidden_states`` position, so ``hidden_size`` has to match the
    decoder's hidden width.

    Args:
        input_size: number of features per colour.
        hidden_size: size of the encoded context vector.
    """

    def __init__(self, input_size, hidden_size=768):
        super().__init__()
        self.hidden_size = hidden_size
        self.encoder = Colors_Feature_Encoder(input_size, hidden_size)
        # Only the decoder half of the encoder-decoder pair is kept.
        # NOTE(review): this still instantiates bert-base-uncased just to discard it.
        self.decoder = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "distilgpt2").decoder

    def _encode_context(self, feats, labels):
        """Encode the colour context as ``(batch, 1, hidden_size)`` for the decoder."""
        batch_size = len(feats)
        return self.encoder(feats, labels).reshape(batch_size, 1, self.hidden_size)

    def forward(self, feats, labels, langs):
        """Teacher-forced pass.

        Args:
            feats: colour features of the batch.
            labels: target indicator per example, forwarded to the encoder.
            langs: token ids per example; the last position is dropped to form
                the decoder input.

        Returns:
            The first element of the decoder output (used as next-token logits
            by the callers), one row per input position.
        """
        decoder_hidden = self._encode_context(feats, labels)
        decoder_output = self.decoder(input_ids=langs[:, :-1], encoder_hidden_states=decoder_hidden)
        return decoder_output[0]

    def generate(self,feats,labels,max_len=5,temperature=0.7):
        """Sample ``max_len`` tokens per example with temperature sampling.

        Relies on the notebook-level ``tokenizer`` and ``SOS``.

        Returns:
            generated: ``(batch, max_len + 1)`` token ids, start token first.
            probs_list: ``(batch, max_len + 1, vocab)`` distributions; position
                0 is a one-hot at ``SOS``, position t > 0 is the distribution
                the t-th token was sampled from.
        """
        decoder_hidden = self._encode_context(feats, labels)
        batch_size = decoder_hidden.shape[0]
        device = decoder_hidden.device
        start_ids = tokenizer.encode("<|endoftext|>")
        generated = torch.tensor(start_ids*batch_size, device=device).unsqueeze(1)
        # Vocabulary size comes from the decoder config instead of a hard-coded
        # 50257, so the buffer always matches the width of the logits.
        vocab_size = self.decoder.config.vocab_size
        probs_list = torch.zeros(batch_size, 1, vocab_size, device=device)
        probs_list[:, 0, SOS] = 1.0
        for _ in range(max_len):
            decoder_output = self.decoder(input_ids=generated, encoder_hidden_states=decoder_hidden)
            # Only the distribution at the last position is needed for the next token.
            logits = decoder_output[0][:,-1,:]/temperature
            probs = F.softmax(logits,dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
            probs_list = torch.cat((probs_list,probs.unsqueeze(1)),dim=1)
        return generated,probs_list

## Training

### Setting

In [None]:
from literal_listener_color import Emb_RNN_L0

speaker = S1_EncoderDecoder(input_size=3)
speaker.to(device)

# Freeze the parameters of the decoder's output-embedding module.
# NOTE(review): the original comment here read "First freeze all the weights",
# but only the module returned by get_output_embeddings() is frozen; every
# other decoder parameter stays trainable. If the intent was to train only the
# last block, iterate over speaker.decoder.parameters() instead -- confirm
# before changing, since it alters training.
output_embeddings = speaker.decoder.get_output_embeddings()
for param in output_embeddings.parameters():
    param.requires_grad = False
# Explicitly mark the last transformer block (h[-1:]) as trainable.
trainable_layers = speaker.decoder.transformer.h[-1:]
for layer in trainable_layers:
    for param in layer.parameters():
        param.requires_grad = True

# Pretrained literal listener. The path is built with os.path.join: the former
# literal "model_params\emb-..." contained the invalid escape sequence "\e"
# and only resolved on Windows.
literal_listener = Emb_RNN_L0(len(vocab_dict)).to(device)
listener_checkpoint = os.path.join("model_params", "emb-rnn-l0_epoch=100_best-acc.pth")
literal_listener.load_state_dict(torch.load(listener_checkpoint, map_location=device))

optimizer = optim.Adam(speaker.parameters(),lr=0.001)
# Padding positions do not contribute to the language-modelling loss.
criterion = nn.CrossEntropyLoss(ignore_index=PAD)

# Number of tokens sampled by speaker.generate().
max_len = 5

# Number of training epochs.
epoch = 10

In [None]:
def check_data(cols, label, c_lang, g_lang, lis_label):
    """Plot one example: three colour swatches plus reference and generated text.

    Args:
        cols: the example's three colours; each ``cols[k]`` must support ``.numpy()``.
        label: per-colour target scores; a swatch whose score exceeds 0.5 is
            outlined in black, the others in their own colour.
        c_lang: words of the reference utterance (joined with spaces in the title).
        g_lang: words of the generated utterance (joined with spaces in the title).
        lis_label: per-colour listener scores; the arg-max panel is titled
            "Listener Prediction".
    """
    colors = [cols[k].numpy() for k in range(3)]
    title = "Correct: "+" ".join(c_lang)+"\nGenerated:"+" ".join(g_lang)+"\n "
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(4, 6))
    fig.suptitle(title)
    for k, (ax, color) in enumerate(zip(axes, colors)):
        edge = "black" if label[k] > 0.5 else color
        ax.add_patch(mpatch.Rectangle((0, 0), 1, 1, color=color, ec=edge, lw=8))
        ax.axis('off')
        if torch.argmax(lis_label)==k:
            ax.set_title("Listener Prediction")
    plt.show()

### Start

In [None]:
from tqdm import tqdm

train_loss_list = []
test_loss_list = []
test_l0_acc_list = []
# Any finite first test loss counts as an improvement.
best_loss = float("inf")
best_speaker_path = "model_params/color_S1-GPT2-decoder_loss=Lang_best-loss_freeze-1.pth"

for i in range(epoch):
    print("##############################################")
    print("Epoch:{}/{}".format(i+1,epoch))
    train_loss = 0
    test_loss = 0

    # ---- train: teacher-forced next-token prediction ----
    literal_listener.eval()
    speaker.train()
    for (cols,lang),label in tqdm(train_batch):
        cols, lang, label = cols.to(device), lang.to(device), label.to(device)
        optimizer.zero_grad()
        output = speaker(cols, label, lang)
        # Position t of the output is scored against token t+1 of the utterance.
        output_view = output.view(-1, output.shape[-1])
        target = lang[:,1:].reshape(-1)
        lang_loss = criterion(output_view, target)
        # The graph is rebuilt every step, so it does not need to be retained.
        lang_loss.backward()
        optimizer.step()
        train_loss += lang_loss.item()
    batch_train_loss = train_loss/len(train_batch)

    # ---- evaluate ----
    speaker.eval()
    with torch.no_grad():
        for (cols,lang),label in tqdm(test_batch):
            cols, lang, label = cols.to(device), lang.to(device), label.to(device)
            output = speaker(cols, label, lang)
            output_view = output.view(-1, output.shape[-1])
            target = lang[:,1:].reshape(-1)
            lang_loss = criterion(output_view, target)
            test_loss += lang_loss.item()
        batch_test_loss = test_loss/len(test_batch)
        # Sampling is only needed for the example printed below, so it runs
        # once on the last test batch instead of on every batch.
        generated,lang_probs = speaker.generate(cols,label,max_len=max_len)

    print("Train Loss:{:.2E}, Test Loss:{:.2E}".format(batch_train_loss,batch_test_loss))
    train_loss_list.append(batch_train_loss)
    test_loss_list.append(batch_test_loss)
    if batch_test_loss < best_loss:
        print("Best loss saved ...")
        torch.save(speaker.state_dict(), best_speaker_path)
        best_loss = batch_test_loss

    # ---- show one random example from the last test batch ----
    sample_idx = np.random.randint(len(cols))
    special_ids = (PAD, SOS, EOS)
    c_langs = tokenizer.decode([idx for idx in lang[sample_idx].to("cpu").tolist() if idx not in special_ids]).split(" ")
    # "Generated" in the figure is the teacher-forced arg-max prediction.
    g_langs = tokenizer.decode([idx for idx in output.argmax(2)[sample_idx].to("cpu").tolist() if idx not in special_ids]).split(" ")
    # No listener is run here: the "Listener Prediction" marker is simply placed on index 2.
    lis_label = torch.zeros(3)
    lis_label[2] = 1.0
    check_data(cols[sample_idx].to("cpu"), label[sample_idx].to("cpu"), c_langs, g_langs, lis_label)
    # The freely sampled utterance for the same example.
    print(tokenizer.decode(generated[sample_idx]))

## Accuracy

In [None]:
from literal_listener_color import Emb_RNN_L0
from tqdm import tqdm

# Rebuild the speaker and restore the checkpoint with the best test loss.
speaker = S1_EncoderDecoder(input_size=3).to(device)
best = "model_params/color_S1-GPT2-decoder_loss=Lang_best-loss_freeze-1.pth"
speaker.load_state_dict(torch.load(best,map_location=device))
print("Model loaded ...")

# Pretrained literal listener. os.path.join replaces the former literal
# "model_params\emb-...", whose "\e" is an invalid escape sequence and whose
# backslash separator only resolved on Windows.
literal_listener = Emb_RNN_L0(len(vocab_dict)).to(device)
listener_checkpoint = os.path.join("model_params", "emb-rnn-l0_epoch=100_best-acc.pth")
literal_listener.load_state_dict(torch.load(listener_checkpoint, map_location=device))

optimizer = optim.Adam(speaker.parameters(),lr=0.001)
# Plain cross-entropy (no ignore_index), used by the L1 evaluation below.
criterion = nn.CrossEntropyLoss()

# Number of tokens sampled by speaker.generate().
max_len = 5


### L0 accuracy

In [None]:
from nltk.tokenize import word_tokenize
def gpt_lang2L0_lang(generated_langs):
    """Re-encode GPT-tokenised utterances with the word-level vocabulary.

    Each sequence is stripped of the GPT special ids (``PAD``, ``SOS``,
    ``EOS``), decoded to text, split with ``word_tokenize`` and mapped through
    ``w2i`` (unknown words become ``original_UNK``). Rows are wrapped as
    ``original_SOS`` + ids + ``original_EOS`` and right-padded with
    ``original_PAD`` to a common length.

    Args:
        generated_langs: a 2-D tensor of GPT token ids, or an iterable of
            id sequences.

    Returns:
        A ``torch.long`` tensor of shape ``(n_sequences, longest + 2)``.
    """
    # Convert once to plain Python ints instead of comparing 0-d tensors one by one.
    if torch.is_tensor(generated_langs):
        generated_langs = generated_langs.tolist()
    special_ids = {PAD, SOS, EOS}
    langs = [
        tokenizer.decode([int(idx) for idx in generated if int(idx) not in special_ids])
        for generated in generated_langs
    ]
    tokens = [[w2i.get(w,original_UNK) for w in word_tokenize(l)] for l in langs]
    # default=0 keeps an empty batch from raising.
    max_tokens_len = max((len(t) for t in tokens), default=0)
    padded = [
        [original_SOS]+ts+[original_EOS]+[original_PAD]*(max_tokens_len-len(ts))
        for ts in tokens
    ]
    # An explicit dtype keeps the ids 64-bit regardless of the platform default;
    # the reshape only matters for an empty batch, where it yields shape (0, 2).
    return torch.tensor(padded, dtype=torch.long).reshape(len(padded), max_tokens_len + 2)

In [None]:
# L0 accuracy: how often the literal listener picks the target colour when it
# is given an utterance sampled from the speaker.
accs = []        # per-batch accuracies, kept for inspection
n_correct = 0
n_total = 0
speaker.eval()
literal_listener.eval()
with torch.no_grad():
    for (cols,lang),label in test_batch:
        cols, lang, label = cols.to(device), lang.to(device), label.to(device)
        generated_lang, lang_probs = speaker.generate(cols,label,max_len=max_len)
        # Re-encode the GPT tokens with the listener's word-level vocabulary.
        output_lang = gpt_lang2L0_lang(generated_lang).to(device)
        lis_labels = literal_listener(cols, output_lang)
        pred_labels = torch.argmax(lis_labels,dim=1)
        # The target index is read from the labels instead of being hard-coded to 2.
        correct_labels = torch.argmax(label,dim=1)
        hits = (pred_labels == correct_labels).sum().item()
        n_correct += hits
        n_total += len(correct_labels)
        accs.append(hits/len(correct_labels))

# Overall accuracy over all examples. Averaging the per-batch accuracies would
# over-weight a smaller final batch.
print("Accuracy: ", n_correct/n_total if n_total else float("nan"))

### L1 accuracy

In [None]:
def _utterance_log_score(step_probs, token_ids, eps=0.001):
    """Sum over positions of log(p[token] + eps), one score per example.

    step_probs is indexed as [example, position, vocab id] and token_ids as
    [example, position]; only the positions present in both are used.
    """
    steps = min(step_probs.shape[1], token_ids.shape[1])
    picked = step_probs[:, :steps].gather(2, token_ids[:, :steps].unsqueeze(-1)).squeeze(-1)
    return torch.log(picked + eps).sum(dim=1)


def _strip_special(token_ids):
    """Return the ids of a 1-D id tensor as a list without PAD / SOS / EOS."""
    return [idx for idx in token_ids.to("cpu").tolist() if idx not in (PAD, SOS, EOS)]


# L1 accuracy: score the reference utterance under each of the three possible
# targets and predict the target with the highest score.
losss = []
accs = []        # per-batch accuracies, kept for inspection
n_correct = 0
n_total = 0
speaker.eval()

# Pure evaluation: no autograd graph is needed.
with torch.no_grad():
    for i,((cols,lang),label) in enumerate(test_batch):
        cols, lang, label = cols.to(device).to(torch.float), lang.to(device), label.to(device).to(torch.float)

        generated_per_target = []
        scores = []
        for target in range(3):
            # Pretend colour `target` is the one the speaker has to describe.
            hypothesis = torch.zeros_like(label)
            hypothesis[:, target] = 1.0
            generated_lang, lang_probs = speaker.generate(cols, hypothesis, max_len=max_len)
            generated_per_target.append(generated_lang)
            # NOTE(review): these distributions are conditioned on the *sampled*
            # prefix, not on the reference utterance, so this is not the
            # teacher-forced likelihood of `lang`. Consider scoring with
            # speaker(cols, hypothesis, lang) instead -- confirm the intent.
            scores.append(_utterance_log_score(lang_probs, lang))

        # One summed log-score per (example, candidate target).
        logits = torch.stack(scores, dim=1)
        probs = F.softmax(logits, dim=-1)
        correct_labels = torch.argmax(label, dim=1)
        # CrossEntropyLoss applies log-softmax itself, so it receives the raw
        # scores rather than the already-normalised probabilities.
        loss = criterion(logits, correct_labels)
        losss.append(loss.item())
        pred_labels = torch.argmax(probs,dim=1)
        hits = (pred_labels == correct_labels).sum().item()
        n_correct += hits
        n_total += len(correct_labels)
        accs.append(hits/len(correct_labels))

        if i%40 == 0:
            sample_idx = np.random.randint(len(cols))
            c_langs = tokenizer.decode(_strip_special(lang[sample_idx])).split(" ")
            # Utterances sampled for target 0 | target 1 | target 2.
            g_langs = []
            for k, generated_lang in enumerate(generated_per_target):
                if k > 0:
                    g_langs.append("|")
                g_langs += tokenizer.decode(_strip_special(generated_lang[sample_idx])).split(" ")
            sample_probs = probs[sample_idx].to("cpu")
            check_data(cols[sample_idx].to("cpu"), label[sample_idx].to("cpu"), c_langs, g_langs, sample_probs)
            # `probs` already holds probabilities, so they are printed as they
            # are (they were previously passed through exp()).
            print(sample_probs)
            print(torch.where(sample_probs>0.1,1,0))

print("Loss: ",np.mean(losss))
# Overall accuracy over all examples rather than the mean of per-batch accuracies.
print("Accuracy: ", n_correct/n_total if n_total else float("nan"))