In [None]:
from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatch
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from shapeworld_data import load_raw_data, get_vocab, ShapeWorld

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
def check_raw_data(imgs, labels, langs, id=0):
    """Plot every candidate image of one raw example under its caption.

    Args:
        imgs: Indexable collection; ``imgs[id]`` is the sequence of candidate
            images that belong to one example.
        labels: Indexable collection; ``labels[id]`` holds one flag per
            candidate image, where ``1`` marks the correct image.
        langs: Indexable collection; ``langs[id]`` is the sequence of caption
            tokens, joined with spaces to form the figure title.
        id: Index of the example to display (defaults to the first one).
    """
    # Index the requested example directly instead of zipping and
    # materialising the whole dataset just to read a single entry.
    img_list, label, lang = imgs[id], labels[id], langs[id]
    n_imgs = len(img_list)
    # One axis per candidate image, 2 inches wide each (three images give the
    # 6x2 figure). squeeze=False keeps `axes` indexable even for one image.
    fig, axes = plt.subplots(nrows=1, ncols=n_imgs, figsize=(2 * n_imgs, 2), squeeze=False)
    axes = axes[0]
    fig.suptitle(" ".join(lang))
    for ax, l, img in zip(axes, label, img_list):
        # NOTE(review): reversing the axes presumably converts a channel-first
        # image for imshow, but for a (C, H, W) array it yields (W, H, C),
        # i.e. a transposed picture -- confirm whether (1, 2, 0) was intended.
        ax.imshow(img.transpose((2, 1, 0)))
        if l == 1:
            ax.set_title("Correct")
    plt.show()

## Main test code

### Prepare data

In [None]:
# The project root is taken to be two directory levels above the current
# working directory (os.path.abspath('') resolves to the working directory).
root = Path(os.path.abspath('')).parent.parent.absolute()
# Join the components separately: a single "data\shapeworld_np" literal is
# only a valid path on Windows and contains the invalid escape sequence "\s".
data_path = os.path.join(root, "data", "shapeworld_np")
print(data_path)
# os.listdir returns entries in arbitrary order; sort them so that the
# positional indexing used by later cells picks the same files on every run.
data_list = sorted(os.listdir(data_path))
print(data_list)

In [None]:
# Build one vocabulary from every raw data file found in the data directory.
data_files = []
for file_name in data_list:
    data_files.append(os.path.join(data_path, file_name))
vocab = get_vocab(data_files)
print(vocab["w2i"])

In [None]:
# Load the first raw file once: it is inspected here and also backs the first
# training loader, so it does not need to be read from disk a second time.
d = load_raw_data(os.path.join(data_path,data_list[0]))
print(d["imgs"].shape)
print(d["labels"].shape)
print(d["langs"].shape)
check_raw_data(d["imgs"],d["labels"],d["langs"])


def _make_loader(raw_data):
    """Wrap one loaded raw data object in a ShapeWorld dataset and a fixed-order loader."""
    return DataLoader(ShapeWorld(raw_data, vocab), batch_size=32, shuffle=False)


# Files 0-3 form the training set; file 0 is already in memory as `d`.
train_batchs = [_make_loader(d)] + [
    _make_loader(load_raw_data(os.path.join(data_path, file_name)))
    for file_name in data_list[1:4]
]
# Keep the individual names defined for any cell that refers to them.
train_batch0, train_batch1, train_batch2, train_batch3 = train_batchs

In [None]:
# The file at index 4 is held out for evaluation: load it, report the array
# shapes, preview the first example, then wrap it in a fixed-order loader.
test_file = os.path.join(data_path, data_list[4])
d = load_raw_data(test_file)
for field in ("imgs", "labels", "langs"):
    print(d[field].shape)
check_raw_data(d["imgs"], d["labels"], d["langs"])
test_dataset = ShapeWorld(d, vocab)
test_batch = DataLoader(test_dataset, batch_size=32, shuffle=False)

### Model setting

In [None]:
from vision import ConvNet

class ShapeWorld_RNN_L0(nn.Module):
    """Listener that scores candidate images against a token sequence.

    The token sequence is encoded by an embedding layer followed by a
    bidirectional GRU, every candidate image is encoded by ``ConvNet``, and
    each image is scored by the dot product of the two encodings.

    Args:
        vocab_size: Number of rows in the token embedding table.
        emb_dim: Size of each token embedding.
        hidden_dim: Size of the GRU hidden state (per direction).
    """

    def __init__(self,vocab_size,emb_dim=768,hidden_dim=1024) -> None:
        super().__init__()
        self.embedding_dim = emb_dim
        self.hidden_dim = hidden_dim
        # Token index 0 is treated as padding (padding_idx=0).
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        # NOTE(review): the dot product in forward() requires the encoder's
        # output width to equal hidden_dim -- confirm against vision.ConvNet.
        self.cnn_encoder = ConvNet(4)
        # nn.GRU applies dropout only between stacked layers; with a single
        # layer a non-zero value has no effect and merely raises a warning,
        # so no dropout argument is passed.
        self.rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True, bidirectional=True)

    def embed_features(self, feats):
        """Encode every image of every example with the CNN encoder.

        Args:
            feats: Tensor whose first two axes are (batch, n_obj); the
                remaining axes describe a single image as the encoder
                expects it.

        Returns:
            Tensor of shape (batch, n_obj, feature_dim).
        """
        batch_size, n_obj = feats.shape[:2]
        # Fold the object axis into the batch axis so the encoder receives an
        # ordinary batch of images.
        feats_flat = feats.reshape(batch_size * n_obj, *feats.shape[2:])
        feats_emb_flat = self.cnn_encoder(feats_flat)
        # reshape (rather than view) also copes with a non-contiguous output.
        return feats_emb_flat.reshape(batch_size, n_obj, -1)

    def forward(self,imgs,contexts):
        """Return, per example, a probability for each candidate image.

        Args:
            imgs: Tensor with leading axes (batch, n_obj) holding the
                candidate images.
            contexts: Integer tensor of token indices, batch first.

        Returns:
            Tensor of shape (batch, n_obj) whose rows sum to 1. The values are
            already softmax-normalised, so they should not be passed to a loss
            that applies its own softmax to its input.
        """
        embs = self.embedding(contexts)
        _, hidden = self.rnn(embs)
        # `hidden` stacks the final state of each direction; the last entry
        # is the final state of the backward direction.
        lang_embs = hidden[-1].view(-1,self.hidden_dim)
        imgs_emb = self.embed_features(imgs)
        # Dot product between each image encoding and its example's language
        # encoding.
        logits = torch.einsum('ijh,ih->ij', imgs_emb, lang_embs)
        # Normalise over the candidate axis; dim is given explicitly because
        # the implicit choice is deprecated.
        scores = F.softmax(logits, dim=1)
        return scores


In [None]:
# Size the embedding table from the word-to-index mapping, then move the
# freshly built model onto the selected device. nn.Module.to returns the
# module itself, so the trailing expression still displays its summary.
vocab_size = len(vocab["w2i"])
model = ShapeWorld_RNN_L0(vocab_size).to(device)
model

In [None]:
class ProbabilityCrossEntropy(nn.Module):
    """Cross-entropy between target distributions and predicted probabilities.

    The model's forward() already ends in a softmax, whereas
    nn.CrossEntropyLoss expects unnormalised logits and applies log-softmax
    itself. Feeding it the model output would normalise twice, so this loss
    takes the logarithm of the probabilities directly.

    Args:
        eps: Lower bound applied to the probabilities before the logarithm,
            to avoid log(0).
    """

    def __init__(self, eps=1e-12):
        super().__init__()
        self.eps = eps

    def forward(self, probs, targets):
        """Return the batch mean of ``-sum(targets * log(probs))`` over dim 1.

        Args:
            probs: Tensor of shape (batch, n_classes) with predicted
                probabilities.
            targets: Float tensor of the same shape holding the target
                distribution for each example.
        """
        log_probs = torch.log(probs.clamp_min(self.eps))
        return -(targets * log_probs).sum(dim=1).mean()


criterion = ProbabilityCrossEntropy()
optimizer = optim.Adam(model.parameters())
epoch = 100

### Training

In [None]:
train_loss_list = []
test_loss_list = []
# Start from +inf so the first epoch always produces a checkpoint, whatever
# the scale of the loss.
best_score = float("inf")
checkpoint_path = "model_params/shapeworld_rnn_full-data_l0.pth"
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
# The epoch loss is summed over every training loader, so it has to be
# averaged over the total number of batches, not over one loader's length.
n_train_batches = sum(len(loader) for loader in train_batchs)
for i in range(epoch):
    print("##############################################")
    print("Epoch:{}/{}".format(i+1,epoch))
    train_loss = 0
    test_loss = 0

    model.train()
    for train_batch in train_batchs:
        for imgs,labels,langs in train_batch:
            imgs,labels,langs = imgs.to(torch.float).to(device),labels.to(torch.float).to(device),langs.to(device)
            optimizer.zero_grad()
            y_pred = model(imgs,langs)
            loss = criterion(y_pred,labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
    batch_train_loss = train_loss/n_train_batches

    model.eval()
    # No gradients are needed for evaluation; skipping them avoids building
    # an autograd graph for every test batch.
    with torch.no_grad():
        for imgs,labels,langs in test_batch:
            imgs,labels,langs = imgs.to(torch.float).to(device),labels.to(torch.float).to(device),langs.to(device)
            y_pred = model(imgs,langs)
            loss = criterion(y_pred,labels)
            test_loss += loss.item()
    batch_test_loss = test_loss/len(test_batch)

    print("Train Loss:{:.2E}, Test Loss:{:.2E}".format(batch_train_loss,batch_test_loss))
    train_loss_list.append(batch_train_loss)
    test_loss_list.append(batch_test_loss)
    # NOTE(review): the "best" checkpoint is chosen by training loss, which
    # tends to favour the most recent epoch -- confirm this is intended
    # rather than selecting on held-out loss.
    if batch_train_loss < best_score:
        best_score = batch_train_loss
        torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Save the weights as they are after the final epoch, whatever their loss.
last_checkpoint_path = "model_params/shapeworld_rnn_full-data_100epoch_l0_last.pth"
# torch.save does not create missing directories, so make sure it exists.
os.makedirs(os.path.dirname(last_checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), last_checkpoint_path)

### Accuracy Test

In [None]:
correct_num = 0
total_num = 0
# Put the model in eval mode here rather than relying on an earlier cell
# having left it that way.
model.eval()
# Inference only: gradients are not needed.
with torch.no_grad():
    for imgs,labels,langs in test_batch:
        imgs,labels,langs = imgs.to(torch.float).to(device),labels.to(torch.float).to(device),langs.to(device)
        y_pred_prob = model(imgs,langs)
        # Index of the highest-scoring candidate for each example.
        y_pred = torch.max(y_pred_prob,1)[1]
        # Index of the largest label entry for each example.
        labels = torch.max(labels,1)[1]
        correct_num += torch.sum(y_pred==labels).item()
        total_num += len(labels)

print("Total number of data for this evaluatio is ",total_num)
print("Classification accuracy is ",correct_num/total_num)