In [None]:
from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from shapeworld_data import load_raw_data, get_vocab, ShapeWorld

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
def check_raw_data(imgs, labels, langs, id=0):
    """Show one raw example: its candidate images side by side, captioned by the utterance.

    Args:
        imgs: sequence of per-example image groups.
        labels: sequence of per-example label vectors, aligned with ``imgs``.
        langs: sequence of per-example token sequences, aligned with ``imgs``.
        id: index of the example to display.
    """
    examples = list(zip(imgs, labels, langs))
    pictures, flags, tokens = examples[id]

    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(6, 2))
    # The utterance tokens become the figure title.
    fig.suptitle(" ".join(tokens))

    for col, (flag, picture) in enumerate(zip(flags, pictures)):
        panel = axes[col]
        # Reverse the axis order so that imshow receives the channel axis last
        # (assumes channel-first storage -- TODO confirm).
        panel.imshow(picture.transpose((2, 1, 0)))
        if flag == 1:
            panel.set_title("Correct")
    plt.show()

## Main test code

### Prepare data

In [None]:
# Directory two levels above the notebook's current working directory.
root = Path(os.path.abspath('')).parent.parent.absolute()
# Pass the components separately: the former "data\shapeworld_np" literal
# contained an invalid "\s" escape and only resolved correctly on Windows.
data_path = os.path.join(root, "data", "shapeworld_np")
print(data_path)
# os.listdir returns entries in arbitrary order; sort them so that
# data_list[0] / data_list[-1] name the same files on every run and platform.
data_list = sorted(os.listdir(data_path))
print(data_list)

In [None]:
# Build the vocabulary from every file listed in the data directory.
vocab = get_vocab([os.path.join(data_path, fname) for fname in data_list])
print(vocab["w2i"])

# One-hot codes for the six named colours; "other" is the all-zero vector.
COLOR = {
    name: [int(slot == pos) for slot in range(6)]
    for pos, name in enumerate(("white", "green", "gray", "yellow", "red", "blue"))
}
COLOR["other"] = [0] * 6

# One-hot codes for the four concrete shapes; the generic word "shape" is all zeros.
SHAPE = {"shape": [0] * 4}
SHAPE.update(
    (name, [int(slot == pos) for slot in range(4)])
    for pos, name in enumerate(("square", "circle", "rectangle", "ellipse"))
)

In [None]:
# The first file in data_list is used as the training split.
d = load_raw_data(os.path.join(data_path, data_list[0]))
# Print the shapes of the three raw arrays, one per line.
print(*(d[key].shape for key in ("imgs", "labels", "langs")), sep="\n")
check_raw_data(imgs=d["imgs"], labels=d["labels"], langs=d["langs"])
# NOTE(review): the training loader is not shuffled -- confirm this is intentional.
train_batch = DataLoader(dataset=ShapeWorld(d, vocab), batch_size=32, shuffle=False)

In [None]:
# The last file in data_list is used as the evaluation split.
d = load_raw_data(os.path.join(data_path, data_list[-1]))
# Print the shapes of the three raw arrays, one per line.
print(*(d[key].shape for key in ("imgs", "labels", "langs")), sep="\n")
check_raw_data(imgs=d["imgs"], labels=d["labels"], langs=d["langs"])
test_batch = DataLoader(dataset=ShapeWorld(d, vocab), batch_size=32, shuffle=False)

### Model setting

#### CNN encoder

In [None]:
from cs_cnn import CNN_encoder
class SimpleBaseLine_ShapeWorld_L0(nn.Module):
    """Baseline listener that scores candidate images against an utterance.

    Every candidate image is described by the concatenated outputs of two CNN
    encoders whose weights are loaded from checkpoints (colour and shape), then
    projected into the word-embedding space.  The utterance is represented by
    the sum of its token embeddings, and a candidate's score is the
    softmax-normalised dot product between the two representations.
    """

    def __init__(self,vocab_size, emb_dim=1024) -> None:
        """Create the embedding table, load both CNN encoders and the projection.

        Args:
            vocab_size: number of rows of the token embedding table.
            emb_dim: size of the shared embedding space.

        The checkpoint paths are relative to the current working directory.
        """
        super(SimpleBaseLine_ShapeWorld_L0,self).__init__()
        self.embedding_dim = emb_dim
        # Index 0 is the padding token, so it contributes a zero vector to the sum.
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.cnn_color_encoder = CNN_encoder(6)
        self.cnn_color_encoder.load_state_dict(torch.load("model_params/shapeworld_original-cnn_color_model.pth",map_location=device))
        self.cnn_shape_encoder = CNN_encoder(4)
        self.cnn_shape_encoder.load_state_dict(torch.load("model_params/shapeworld_original-cnn_shape_model.pth",map_location=device))
        # 10 input features; presumably 6 colour + 4 shape encoder outputs -- TODO confirm.
        self.to_hidden = nn.Linear(10,self.embedding_dim)

    def embed_features(self, imgs):
        """Encode every candidate image into the embedding space.

        Args:
            imgs: tensor whose dimension 1 indexes the candidate images of each
                example.  Any number of candidates is accepted (the previous
                implementation was hard-wired to exactly three).

        Returns:
            The projected features, stacked along dimension 1 in candidate order.
        """
        per_candidate = []
        # Encoders are applied one candidate slot at a time, exactly as before,
        # so any batch-dependent layers inside them see the same inputs.
        for slot in range(imgs.shape[1]):
            candidate = imgs[:, slot]
            color_embs = self.cnn_color_encoder(candidate)
            shape_embs = self.cnn_shape_encoder(candidate)
            per_candidate.append(torch.hstack((color_embs, shape_embs)))
        cnn_emb = torch.stack(per_candidate, dim=1)
        return self.to_hidden(cnn_emb)

    def forward(self,imgs,contexts):
        """Return one softmax-normalised score per candidate image.

        Args:
            imgs: candidate images, see ``embed_features``.
            contexts: integer token ids of the utterance, one row per example.

        Returns:
            Tensor with one row per example and one column per candidate; each
            row sums to 1.

        NOTE(review): the output is already softmax-normalised.  If it is fed to
        ``nn.CrossEntropyLoss`` (which applies log-softmax itself) the scores are
        normalised twice -- confirm whether that is intended.
        """
        embs = self.embedding(contexts)
        # Bag-of-words utterance representation: sum over the token axis.
        lang_embs = torch.sum(embs,dim=1).reshape(-1,self.embedding_dim)
        imgs_emb = self.embed_features(imgs)
        # Dot product between every candidate (j) and its utterance, per example (i).
        scores = F.softmax(torch.einsum('ijh,ih->ij', (imgs_emb, lang_embs)),dim=-1)
        return scores


In [None]:
# One embedding row per vocabulary entry; the model is moved to the selected device.
model = SimpleBaseLine_ShapeWorld_L0(len(vocab["w2i"])).to(device)
# Loss function and optimiser (Adam with its default hyper-parameters).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters())
# Number of training epochs.
epoch = 30

### Training

In [None]:
def get_relative_accuracy(model,test_batch):
    """Return the fraction of examples whose top-scoring candidate is the labelled one.

    The pass runs in evaluation mode with autograd disabled, and the model's
    previous train/eval mode is restored afterwards.

    Args:
        model: callable module taking ``(imgs, langs)`` and returning one score
            per candidate.
        test_batch: iterable of ``(imgs, labels, langs)`` batches; ``labels``
            has one column per candidate and its arg-max marks the target.

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` when the iterable yields no examples.
    """
    # Send each batch to wherever the model's weights live; models without
    # parameters fall back to the notebook-level ``device``.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else device

    correct_num = 0
    total_num = 0
    was_training = model.training
    model.eval()
    try:
        # No gradients are needed for counting hits.
        with torch.no_grad():
            for imgs,labels,langs in test_batch:
                imgs = imgs.to(torch.float).to(target)
                labels = labels.to(torch.float).to(target)
                langs = langs.to(target)
                y_pred = torch.max(model(imgs,langs),1)[1]
                gold = torch.max(labels,1)[1]
                correct_num += torch.sum(y_pred==gold).item()
                total_num += len(gold)
    finally:
        # Leave the model in the mode the caller had set.
        model.train(was_training)
    return correct_num/total_num if total_num else 0.0

In [None]:
def train_model(model,train_batch,criterion,optimizer,do_break=False):
    """Run one optimisation pass over ``train_batch`` and report loss and accuracy.

    Args:
        model: module called as ``model(imgs, langs)``.
        train_batch: iterable of ``(imgs, labels, langs)`` batches.
        criterion: loss called as ``criterion(prediction, labels)``.
        optimizer: optimiser stepping the model's parameters.
        do_break: when true, stop after the first batch (smoke-test mode).

    Returns:
        ``(mean_loss, accuracy)`` where the mean is taken over the batches that
        were actually processed and the accuracy comes from
        ``get_relative_accuracy`` on the full ``train_batch``.

    NOTE(review): if ``model`` already returns softmax probabilities and
    ``criterion`` is ``nn.CrossEntropyLoss``, the scores get normalised twice --
    confirm against the model definition.
    """
    train_loss = 0.0
    n_batches = 0
    model.train()
    for imgs,labels,langs in train_batch:
        imgs,labels,langs = imgs.to(torch.float).to(device),labels.to(torch.float).to(device),langs.to(device)
        optimizer.zero_grad()
        y_pred = model(imgs,langs)
        loss = criterion(y_pred,labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        n_batches += 1
        if do_break: break
    # Average over the batches that ran; dividing by len(train_batch) under-reported
    # the loss whenever do_break cut the pass short.
    batch_train_loss = train_loss/max(n_batches, 1)
    # Measure accuracy in eval mode without building autograd graphs, then
    # return the model to train mode as before.
    model.eval()
    try:
        with torch.no_grad():
            batch_train_acc = get_relative_accuracy(model, train_batch)
    finally:
        model.train()
    return batch_train_loss, batch_train_acc

def eval_model(model,test_batch,criterion,do_break=False):
    """Evaluate ``model`` on ``test_batch`` without updating its parameters.

    Args:
        model: module called as ``model(imgs, langs)``.
        test_batch: iterable of ``(imgs, labels, langs)`` batches.
        criterion: loss called as ``criterion(prediction, labels)``.
        do_break: when true, stop the loss loop after the first batch.

    Returns:
        ``(mean_loss, accuracy)`` where the mean is taken over the batches that
        were actually processed and the accuracy comes from
        ``get_relative_accuracy`` on the full ``test_batch``.
    """
    test_loss = 0.0
    n_batches = 0
    model.eval()
    with torch.no_grad():
        for imgs,labels,langs in test_batch:
            imgs,labels,langs = imgs.to(torch.float).to(device),labels.to(torch.float).to(device),langs.to(device)
            y_pred = model(imgs,langs)
            loss = criterion(y_pred,labels)
            test_loss += loss.item()
            n_batches += 1
            if do_break: break
        # Keep the accuracy pass inside no_grad so no autograd graph is built.
        batch_test_acc = get_relative_accuracy(model,test_batch)
    # Average over the batches that ran; dividing by len(test_batch) under-reported
    # the loss whenever do_break cut the pass short.
    batch_test_loss = test_loss/max(n_batches, 1)
    return batch_test_loss, batch_test_acc


def train_and_eval_epochs(model,criterion,optimizer,epoch,train_batch,test_batch,train_size,log=True,do_break=False):
    """Train for ``epoch`` epochs, evaluating and checkpointing after each one.

    Args:
        model: module to train.
        criterion: loss passed through to ``train_model`` / ``eval_model``.
        optimizer: optimiser passed through to ``train_model``.
        epoch: number of epochs to run.
        train_batch: training batches.
        test_batch: evaluation batches.
        train_size: value embedded in the checkpoint file names.
        log: print progress and checkpoint messages when true.
        do_break: smoke-test mode; forwarded to the per-epoch helpers and stops
            after the first epoch.

    Returns:
        ``(train_losses, test_losses, train_accs, test_accs)``, one entry per
        completed epoch.

    Side effects:
        Writes ``model_params/simple-l0_best-loss_trainSize=<train_size>.pth``
        and ``model_params/simple-l0_best-acc_trainSize=<train_size>.pth``
        whenever the test loss / test accuracy improves.
    """
    train_loss_list = []
    train_acc_list = []
    test_loss_list = []
    test_acc_list = []
    # Infinite sentinels guarantee that the first epoch always produces both
    # checkpoints; the former 100 / 0 start values could skip saving entirely
    # (loss never below 100, or accuracy stuck at 0).
    best_loss = float("inf")
    best_acc = float("-inf")
    os.makedirs("model_params", exist_ok=True)
    for i in range(epoch):
        if log:
            print("##############################################")
            print("Epoch:{}/{}".format(i+1,epoch))
        batch_train_loss, batch_train_acc = train_model(model,train_batch,criterion,optimizer,do_break=do_break)
        batch_test_loss, batch_test_acc = eval_model(model,test_batch,criterion,do_break=do_break)
        if log:
            print("Train Loss:{:.2E}, Test Loss:{:.2E}".format(batch_train_loss,batch_test_loss))
            print("Train Acc:{:.2E}, Test Acc:{:.2E}".format(batch_train_acc,batch_test_acc))
        train_loss_list.append(batch_train_loss)
        test_loss_list.append(batch_test_loss)
        train_acc_list.append(batch_train_acc)
        test_acc_list.append(batch_test_acc)
        if batch_test_loss < best_loss:
            if log: print("Best Loss saved ...")
            torch.save(model.state_dict(), f"model_params/simple-l0_best-loss_trainSize={train_size}.pth")
            best_loss = batch_test_loss
        if batch_test_acc > best_acc:
            if log: print("Best Acc saved ...")
            torch.save(model.state_dict(), f"model_params/simple-l0_best-acc_trainSize={train_size}.pth")
            best_acc = batch_test_acc
        if do_break: break
    return train_loss_list,test_loss_list,train_acc_list,test_acc_list

In [None]:
# train and eval with epoch
# Create the output directory up front so a missing folder cannot make
# np.save fail after the whole training run has finished.
os.makedirs("metrics", exist_ok=True)
tr_loss,ts_loss,tr_acc,ts_acc = train_and_eval_epochs(model,criterion,optimizer,epoch,train_batch,test_batch,train_size=1000,log=True,do_break=False)
# Rows: train loss, test loss, train accuracy, test accuracy; one column per epoch.
metrics = np.array([tr_loss,ts_loss,tr_acc,ts_acc])
np.save("metrics/simple-l0.npy",metrics)

### Accuracy Test

In [None]:
def get_prob_labels(lang_probs):
    """Pick the highest-scoring candidate per row, breaking ties uniformly at random.

    The previous hand-written tie cases mishandled rows where candidates 0 and 2
    were equal: a shared maximum always resolved to index 0, and a unique
    maximum at index 1 was replaced by a random 0/2.  Ties are now detected
    generically, for any number of candidates.

    Args:
        lang_probs: iterable of score rows (torch tensors or array-likes).

    Returns:
        Integer ``np.ndarray`` with the chosen candidate index for every row.
    """
    lang_pred = []
    for probs in lang_probs:
        if torch.is_tensor(probs):
            # Detach and move to host memory so NumPy can read the values.
            probs = probs.detach().cpu().numpy()
        row = np.asarray(probs)
        # Indices of all candidates sharing the maximum score.
        winners = np.flatnonzero(row == row.max())
        if winners.size > 1:
            lang_pred.append(int(np.random.choice(winners)))
        else:
            # Unique maximum (or a NaN row, where no equality holds).
            lang_pred.append(int(np.argmax(row)))
    return np.array(lang_pred, dtype=int)

In [None]:
# Rebuild the model and restore the checkpoint with the best test accuracy.
model = SimpleBaseLine_ShapeWorld_L0(len(vocab["w2i"].keys()))
model.to(device)
model.load_state_dict(torch.load("model_params/simple-l0_best-acc_trainSize=1000.pth",map_location=device))
# A freshly constructed module starts in train mode; switch to eval mode so
# the measurement does not depend on training-only layer behaviour.
model.eval()

correct_num = 0
total_num = 0
# Gradients are not needed for evaluation.
with torch.no_grad():
    for imgs,labels,langs in test_batch:
        imgs,labels,langs = imgs.to(torch.float).to(device),labels.to(torch.float).to(device),langs.to(device)
        y_pred_prob = model(imgs,langs)
        y_pred = torch.max(y_pred_prob,1)[1]
        labels = torch.max(labels,1)[1]
        #y_pred = get_prob_labels(y_pred_prob)
        #labels = np.zeros(imgs.shape[0])
        # torch.sum reduces on-device instead of iterating the tensor element-wise.
        correct_num += torch.sum(y_pred==labels).item()
        total_num += len(labels)

print("Total number of data for this evaluation is ",total_num)
print("Classification accuracy is ",correct_num/total_num)