In [None]:
from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer
from transformers import BertModel
import torch
from torch import nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from corpus import ColorsCorpusReader

## Prepare BERT

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print("Device = ", device)

In [None]:
# Pretrained BERT tokenizer and encoder. The "uncased" checkpoint does not
# distinguish upper and lower case (its tokenizer lower-cases the input).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
# Width of each token embedding. Read it from the loaded model's config rather than
# hard-coding 768, so it always matches the checkpoint (768 for bert-base).
emb_dim = bert_model.config.hidden_size

In [None]:
def sentence2vector(sentence, verbose=False):
    """Encode one sentence with BERT and return its per-token embeddings.

    The sentence is wrapped in "[CLS] ... [SEP]", tokenized with the global
    `tokenizer`, and run through the global `bert_model` on `device` in eval
    mode without gradient tracking.

    Args:
        sentence: raw text to encode.
        verbose: if True, print the sentence before encoding. Off by default
            because this function is mapped over the whole corpus, and printing
            every sentence floods the notebook output.

    Returns:
        (vecs, tokens): `vecs` is the model's first output for the single batch
        item, one row per word-piece token (including [CLS]/[SEP]), still on
        `device`; `tokens` is the list of word-piece strings.
    """
    if verbose:
        print(sentence)
    marked_sents = "[CLS] " + sentence + " [SEP]"
    tokens = tokenizer.tokenize(marked_sents)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokens)
    # Batch of one: shape (1, n_tokens).
    tokens_tensor = torch.tensor([indexed_tokens]).to(device)
    bert_model.to(device)
    bert_model.eval()
    with torch.no_grad():
        outputs = bert_model(tokens_tensor)
    vecs = outputs[0]
    # Drop the batch dimension -> (n_tokens, hidden) rows.
    return vecs[0], tokens

## Main experiments

### Prepare data

In [None]:
# Locate the data directory two levels above the notebook's working directory
# and read the colours corpus from it.
notebook_dir = Path(os.path.abspath(''))
root = notebook_dir.parent.parent.absolute()
data_path = os.path.join(root, "data")
print(data_path)

csv_file = os.path.join(data_path, "colors.csv")
corpus = ColorsCorpusReader(csv_file, word_count=None, normalize_colors=True)
examples = [ex for ex in corpus.read()]
print(f"Number of datapoints: {len(examples)}")

# One correct and one wrong colour per example keeps the two classes balanced;
# both are paired with the same utterance (context).
correct_color_data = [ex.get_l0_data()[0] for ex in examples]
wrong_color_data = [ex.get_l0_negative_data()[0] for ex in examples]
context_data = [ex.get_l0_data()[1][0] for ex in examples]

### Create dataloader and batch

In [None]:
# Batching: colour tensors plus a (n_examples, max_context_len, emb_dim) tensor of
# zero-padded BERT token embeddings for every utterance, cached on disk because
# encoding the whole corpus is slow.
correct_color_data_tensor = torch.tensor(np.array(correct_color_data), dtype=torch.float)
wrong_color_data_tensor = torch.tensor(np.array(wrong_color_data), dtype=torch.float)
context_cache = "tmp/all_contexts_embs_46994x33x768.tensor"
if not os.path.exists(context_cache):
    context_vecs = [vecs for vecs, _tokens in map(sentence2vector, context_data)]
    max_context_len = max(len(c) for c in context_vecs)
    # Pad each (n_tokens, emb_dim) matrix with zero rows up to max_context_len and
    # stack into 3-D, the same layout the cached branch yields. (A flat vstack here
    # would leave downstream code summing over the wrong axis on a fresh run.)
    padded_context_data = torch.stack([
        torch.vstack((c.to("cpu"), torch.zeros(max_context_len - len(c), emb_dim)))
        for c in context_vecs
    ])
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(context_cache), exist_ok=True)
    torch.save(padded_context_data, context_cache)
else:
    print("Context Padded Tensor is loaded...")
    padded_context_data = torch.load(context_cache)
    # Works for both old flat caches and 3-D caches, without hard-coding the padded length.
    padded_context_data = padded_context_data.reshape(len(context_data), -1, emb_dim)
print("Color01 shape = ", correct_color_data_tensor.shape)
print("Color02 shape = ", wrong_color_data_tensor.shape)
print("Padded context id lists shape = ", padded_context_data.shape)

In [None]:
# Prepare training data: every utterance appears twice, once with its correct
# colour (label 1) and once with the wrong colour (label 0).
labels = torch.hstack((torch.ones(len(correct_color_data), dtype=torch.float),
                       torch.zeros(len(wrong_color_data), dtype=torch.float)))
color_data = torch.vstack((correct_color_data_tensor, wrong_color_data_tensor))
# Sum the token embeddings of each utterance into one emb_dim vector (zero padding
# does not change the sum). The reshape accepts either a flat
# (n_examples*max_len, emb_dim) or a 3-D padded tensor.
token_sum_vecs = padded_context_data.reshape(len(context_data), -1, emb_dim).sum(dim=1)
sum_context_data = torch.vstack((token_sum_vecs, token_sum_vecs))
print("Shape of labels = ", labels.shape, "color data = ", color_data.shape, "context data = ", sum_context_data.shape)
# Create (colour, context) input pairs. `.float()` replaces torch.tensor(tensor),
# which warns on copy-construction and made a needless copy per row.
sum_data = [(color, context.float()) for color, context in zip(color_data, sum_context_data)]
print("total data length = ", len(sum_data))

In [None]:
# Split data (70% train / 30% test) and construct dataloaders.
# random_state fixes the shuffle so the split, and the accuracy reported below, is reproducible.
# NOTE(review): the positive and negative sample for the same utterance can land in
# different splits; consider a grouped split if that leakage matters.
train_x, test_x, train_y, test_y = train_test_split(sum_data, labels, train_size=0.7, random_state=0)
train_dataset = list(zip(train_x, train_y))
test_dataset = list(zip(test_x, test_y))
train_batch = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=0)
test_batch = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False, num_workers=0)

### Model setting

In [None]:
class Color_Sent_BERT_L0(nn.Module):
    """Binary scorer for (colour, utterance) pairs.

    Concatenates a 3-value colour vector with a sentence embedding of width
    `emb_dim`, passes it through a small MLP with ReLU activations, and returns
    a sigmoid output. That output is trained against label 1 for the correct
    colour and 0 for the wrong one.

    Args:
        emb_dim: width of the sentence embedding.
        hidden_dim: width of the hidden layers.
        output_dim: number of outputs (1 for a single probability).
    """

    def __init__(self, emb_dim, hidden_dim=100, output_dim=1):
        super(Color_Sent_BERT_L0, self).__init__()
        self.linear01 = nn.Linear(3 + emb_dim, hidden_dim)
        self.linear02 = nn.Linear(hidden_dim, hidden_dim)
        self.linear03 = nn.Linear(hidden_dim, output_dim)

    def forward(self, color_rgb, context_embs):
        """Score a batch of colour/context pairs.

        Args:
            color_rgb: (batch, 3) colour features.
            context_embs: (batch, emb_dim) sentence embeddings.

        Returns:
            (batch, output_dim) tensor of sigmoid outputs in (0, 1).
        """
        x = torch.hstack((color_rgb, context_embs))
        x = F.relu(self.linear01(x))
        # Apply the hidden layer exactly once. Calling linear02 twice would silently
        # tie the weights of two consecutive layers.
        x = F.relu(self.linear02(x))
        y = self.linear03(x)
        return torch.sigmoid(y)

In [None]:
# Seed torch's global generator so weight initialisation is reproducible. DataLoader
# shuffling without an explicit generator also draws from this generator.
torch.manual_seed(0)
model = Color_Sent_BERT_L0(emb_dim=emb_dim)
model.to(device)

In [None]:
# Training length, optimiser over all model parameters, and loss.
# NOTE(review): the model ends in a sigmoid and targets are 0/1, so nn.BCELoss is the
# conventional choice; MSE is kept here to preserve the original experiment.
epoch = 30
trainable_params = model.parameters()
optimizer = optim.Adam(trainable_params)
criterion = nn.MSELoss()

### Training

In [None]:
# Train for `epoch` epochs. After each epoch, evaluate on the test split and save
# the weights whenever the mean test loss improves.
train_loss_list = []
test_loss_list = []
best_loss = float("inf")  # any finite first-epoch loss counts as an improvement
os.makedirs("model_params", exist_ok=True)  # torch.save does not create parent directories
for i in range(epoch):
    print("##############################################")
    print("Epoch:{}/{}".format(i+1,epoch))
    train_loss = 0
    test_loss = 0

    model.train()
    for data, label in train_batch:
        # Each batch is [colours, contexts] collated from the (colour, context) pairs.
        colors = data[0].to(device)
        contexts = data[1].to(device)
        label = label.to(device)
        optimizer.zero_grad()
        y_pred = model(colors, contexts)
        loss = criterion(y_pred, label.view(-1, 1))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    batch_train_loss = train_loss / len(train_batch)  # mean of per-batch losses

    model.eval()
    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for data, label in test_batch:
            colors = data[0].to(device)
            contexts = data[1].to(device)
            label = label.to(device)
            y_pred = model(colors, contexts)
            loss = criterion(y_pred, label.view(-1, 1))
            test_loss += loss.item()
    batch_test_loss = test_loss / len(test_batch)

    print("Train Loss:{:.2E}, Test Loss:{:.2E}".format(batch_train_loss,batch_test_loss))
    train_loss_list.append(batch_train_loss)
    test_loss_list.append(batch_test_loss)
    if batch_test_loss < best_loss:
        torch.save(model.state_dict(), "model_params/bert_sum_l0.pth")
        best_loss = batch_test_loss

## Test accuracy

In [None]:
# Rebuild the network on `device` and restore the weights saved during training.
checkpoint = torch.load("model_params/bert_sum_l0.pth", map_location=device)
model = Color_Sent_BERT_L0(emb_dim=emb_dim).to(device)
model.load_state_dict(checkpoint)

In [None]:
# Classification accuracy of `model` on the test split, thresholding the sigmoid
# output at 0.5.
# NOTE(review): the checkpoint was chosen by test loss, so this accuracy is
# optimistically biased; a separate validation split would avoid that.
test_correct = 0
total_data = 0
model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    for data, label in test_batch:
        colors = data[0].to(device)
        contexts = data[1].to(device)
        label = label.to(device)
        total_data += len(label)
        y_pred = model(colors, contexts)
        y_class = torch.where(y_pred > 0.5, 1, 0).view(-1)
        test_correct += torch.sum(y_class == label).item()
print("Total number of data for this evaluation is ", total_data)
print("Classification accuracy is ", test_correct / total_data)