In [None]:
from pathlib import Path
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from shapeworld_data import load_raw_data, get_vocab, ShapeWorld, All_langs_ShapeWorld

In [None]:
# Select the compute device once: CUDA when a GPU is visible, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
def check_raw_data(imgs, labels, langs, id=0):
    """Plot the images of one raw sample side by side under its utterance.

    Args:
        imgs: Indexable collection of samples; ``imgs[id]`` is iterated as the
            images of that sample. Each image is transposed with axes
            ``(2, 1, 0)`` before being drawn.
        labels: Indexable collection aligned with ``imgs``; an entry equal to
            ``1`` gets the panel title "Correct".
        langs: Indexable collection of token sequences; the tokens of
            ``langs[id]`` are space-joined to form the figure title.
        id: Index of the sample to display (the name shadows the builtin but is
            kept so keyword callers continue to work).
    """
    # Index the requested sample directly instead of zipping the whole dataset.
    img_list, label, lang = imgs[id], labels[id], langs[id]

    # One panel per image; squeeze=False keeps `axes` 2-D even for one image.
    n_imgs = len(img_list)
    fig, axes = plt.subplots(nrows=1, ncols=n_imgs, figsize=(2 * n_imgs, 2), squeeze=False)
    fig.suptitle(" ".join(lang))
    for ax, l, img in zip(axes[0], label, img_list):
        ax.imshow(img.transpose((2, 1, 0)))
        if l == 1:
            ax.set_title("Correct")
    plt.show()

In [None]:
import torchvision
def imshow(img):
    """Show a batch of image tensors as a single grid.

    Args:
        img: Tensor (or list of tensors) accepted by
            ``torchvision.utils.make_grid``. It may live on any device and may
            require grad.
    """
    grid = torchvision.utils.make_grid(img)
    # NOTE(review): this rescaling presumes inputs normalised to [-1, 1] — confirm.
    grid = grid / 2 + 0.5
    # .cpu() is required before .numpy() when the tensor lives on the GPU.
    npimg = grid.detach().cpu().numpy()
    # make_grid returns channel-first data; matplotlib expects channel-last.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

## Prepare data

In [None]:
# Resolve the dataset directory two levels above the current working directory.
root = Path(os.path.abspath('')).parent.parent.absolute()
# Join the components separately: the literal "data\shapeworld..." contained the
# invalid escape "\s" and only formed a valid path on Windows.
data_path = os.path.join(root, "data", "shapeworld_np_all_langs")
print(data_path)
# os.listdir returns entries in arbitrary order; sort so the train/test split
# below (all files but the last vs. the last file) is reproducible.
data_list = sorted(os.listdir(data_path))
print(data_list)

In [None]:
# Build the vocabulary from every data file found in the dataset directory.
data_files = [os.path.join(data_path, file_name) for file_name in data_list]
vocab = get_vocab(data_files)
print(vocab["w2i"])

In [None]:
# One-hot encodings for the attribute words that appear in utterances.
# Key order matters: later cells map a predicted class index back to a word via
# list(COLOR.keys())[idx] and list(SHAPE.keys())[1:][idx].
COLOR = {
    "white":  [1, 0, 0, 0, 0, 0],
    "green":  [0, 1, 0, 0, 0, 0],
    "gray":   [0, 0, 1, 0, 0, 0],
    "yellow": [0, 0, 0, 1, 0, 0],
    "red":    [0, 0, 0, 0, 1, 0],
    "blue":   [0, 0, 0, 0, 0, 1],
    "other":  [0, 0, 0, 0, 0, 0],  # all-zero: no specific color mentioned
}
SHAPE = {
    "shape":     [0, 0, 0, 0],  # all-zero: no specific shape mentioned
    "square":    [1, 0, 0, 0],
    "circle":    [0, 1, 0, 0],
    "rectangle": [0, 0, 1, 0],
    "ellipse":   [0, 0, 0, 1],
}

In [None]:
# Load images: every file except the last is training data, the last is test.
# Each file's "imgs" array is flattened to individual (3, 64, 64) images.
train_img_sets = []
for data in data_list[:-1]:
    d = load_raw_data(os.path.join(data_path, data))
    img_set = d["imgs"].reshape(-1, 3, 64, 64)
    train_img_sets.append(img_set)
# Concatenate along the image axis; unlike np.array(list).reshape(...), this
# also works when the files contain different numbers of samples.
if train_img_sets:
    train_img_array = np.concatenate(train_img_sets, axis=0)
else:
    train_img_array = np.empty((0, 3, 64, 64))
train_imgs = torch.tensor(train_img_array, dtype=torch.float)
print(train_imgs.shape)

test_imgs = torch.tensor(load_raw_data(os.path.join(data_path,data_list[-1]))["imgs"].reshape(-1,3,64,64), dtype=torch.float)
print(test_imgs.shape)

In [None]:
# Preview the first 32 images of the training split, then of the test split.
for split_imgs in (train_imgs, test_imgs):
    imshow(split_imgs[:32])

In [None]:
def utter2tensor(utter):
    """Encode an utterance as a 10-element tensor: 6 color bits then 4 shape bits.

    A single word is looked up as a color first, then as a shape; the other
    half falls back to the all-zero "other" / "shape" entry. Two words are
    read as "<color> <shape>" (an unknown word raises ``KeyError``). Any other
    utterance yields the all-zero encoding.
    """
    words = utter.split(" ")
    # Start from the "nothing specified" encoding and override what is known.
    color_bits = COLOR["other"]
    shape_bits = SHAPE["shape"]

    if len(words) == 1:
        word = words[0]
        if word in COLOR:
            color_bits = COLOR[word]
        elif word in SHAPE:
            shape_bits = SHAPE[word]
    elif len(words) == 2:
        color_bits, shape_bits = COLOR[words[0]], SHAPE[words[1]]

    return torch.tensor(np.array(color_bits + shape_bits))
    

In [None]:
from functools import reduce  # retained so any other cell relying on this import keeps working


def _read_utterances(path):
    """Return the flat list of utterances stored in one data file.

    Each entry of ``d["langs"]`` is a token sequence; its tokens are
    space-joined and split on " # " into individual utterances. A nested
    comprehension flattens in linear time (repeated list ``+`` is quadratic).
    """
    d = load_raw_data(path)
    return [u for lang in d["langs"] for u in " ".join(lang).split(" # ")]


def _utterances_to_labels(utterances):
    """Stack the ``utter2tensor`` encoding of each utterance into one 2-D tensor."""
    return torch.vstack([utter2tensor(u) for u in utterances])


# Training labels: every file except the last, in the same order as the images.
train_labels = torch.vstack([
    _utterances_to_labels(_read_utterances(os.path.join(data_path, data)))
    for data in data_list[:-1]
])
print(train_labels.shape)

# Test labels: the last file.
utter_set = _read_utterances(os.path.join(data_path, data_list[-1]))
test_labels = _utterances_to_labels(utter_set)
print(test_labels.shape)

In [None]:
# Split the 10-column label tensors: columns 0-5 are color bits, 6-9 shape bits.
train_color_labels = train_labels[:, :6]
train_shape_labels = train_labels[:, 6:]
test_color_labels = test_labels[:, :6]
test_shape_labels = test_labels[:, 6:]
print(train_color_labels.shape, test_color_labels.shape)
print(train_shape_labels.shape, test_shape_labels.shape)

In [None]:
def _paired_loader(images, targets, shuffle):
    """Wrap aligned images and targets as a DataLoader with batch size 32."""
    return DataLoader(list(zip(images, targets)), batch_size=32, shuffle=shuffle)


# Training loaders shuffle every epoch; test loaders keep the dataset order.
color_train_loader = _paired_loader(train_imgs, train_color_labels, shuffle=True)
color_test_loader = _paired_loader(test_imgs, test_color_labels, shuffle=False)
shape_train_loader = _paired_loader(train_imgs, train_shape_labels, shuffle=True)
shape_test_loader = _paired_loader(test_imgs, test_shape_labels, shuffle=False)

## Model

In [None]:
class CNN_encoder(nn.Module):
    """Classifier head on top of a frozen, pre-trained CNN encoder.

    The encoder is loaded from ``model_params/cnn_autoencoder3-16-32.cnnet``
    and all of its parameters are frozen; only the three fully connected
    layers are trainable.

    Args:
        output_dim: Number of classes produced by the final layer.
    """

    def __init__(self,output_dim):
        super().__init__()
        self.output_dim = output_dim
        # NOTE(review): this unpickles a whole module object, which can execute
        # arbitrary code — only load checkpoints from a trusted source.
        # map_location="cpu" lets a GPU-saved encoder load on CPU-only machines;
        # callers move the whole model with .to(device) afterwards.
        self.enc = torch.load("model_params/cnn_autoencoder3-16-32.cnnet", map_location="cpu")
        # Freeze every encoder parameter (weights AND biases) without relying on
        # the string form of each layer or on layers exposing a `.weight`.
        for param in self.enc.parameters():
            param.requires_grad = False
        # The head expects the encoder output to flatten to 32*16*16 features.
        self.fc1 = nn.Linear(32*16*16,1000)
        self.fc2 = nn.Linear(1000,100)
        self.fc3 = nn.Linear(100, self.output_dim)

    def forward(self,img):
        """Return per-class probabilities of shape ``(batch, output_dim)``.

        The output is already softmax-normalised. ``nn.CrossEntropyLoss``
        applies its own log-softmax, so it should be given logits or
        log-probabilities rather than these values directly.
        """
        x = self.enc(img)
        x = x.reshape(-1, 32*16*16)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        y_prob = F.softmax(self.fc3(x),dim=1)
        return y_prob

In [None]:
# Smoke test: push a single training batch through a fresh 6-class model and
# print the output shape.
model = CNN_encoder(6)
model = model.to(device)

for imgs, labels in color_train_loader:
    imgs = imgs.to(device)
    labels = labels.to(device)
    yprob = model(imgs)
    print(yprob.shape)
    break  # one batch is enough


## Training

### Color training

In [None]:
# Color classifier setup: 6 output classes, Adam with default hyper-parameters.
epoch = 5  # number of passes over the training set
color_model = CNN_encoder(6)
color_model = color_model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(color_model.parameters())

In [None]:
# Train the color classifier and track the mean per-batch loss of each epoch.
train_loss_list = []
test_loss_list = []
for i in range(epoch):
    print("##############################################")
    print("Epoch:{}/{}".format(i+1,epoch))
    train_loss = 0
    test_loss = 0

    # Switch the model being trained (not the smoke-test `model`) to train mode.
    color_model.train()
    for imgs,labels in color_train_loader:
        imgs,labels = imgs.to(torch.float).to(device),labels.to(torch.float).to(device)
        optimizer.zero_grad()
        y_pred = color_model(imgs)
        # color_model returns softmax probabilities, while CrossEntropyLoss
        # applies log-softmax itself. Passing log-probabilities makes that
        # internal log-softmax a no-op, giving the true cross-entropy instead
        # of a doubly-softmaxed one. The clamp guards against log(0).
        loss = criterion(torch.log(y_pred.clamp_min(1e-12)),labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    batch_train_loss = train_loss/len(color_train_loader)

    color_model.eval()
    # No gradients are needed for evaluation; skip building autograd graphs.
    with torch.no_grad():
        for imgs,labels in color_test_loader:
            imgs,labels = imgs.to(torch.float).to(device),labels.to(torch.float).to(device)
            y_pred = color_model(imgs)
            loss = criterion(torch.log(y_pred.clamp_min(1e-12)),labels)
            test_loss += loss.item()
    batch_test_loss = test_loss/len(color_test_loader)

    print("Train Loss:{:.2E}, Test Loss:{:.2E}".format(batch_train_loss,batch_test_loss))
    train_loss_list.append(batch_train_loss)
    test_loss_list.append(batch_test_loss)

In [None]:
# Persist the trained color classifier's parameters.
color_state = color_model.to(device).state_dict()
torch.save(color_state, "model_params/shapeworld_cnn_color_model.pth")

#### Accuracy test

In [None]:
# Keep only test samples whose color label is not all-zero, i.e. samples that
# actually name a color, so accuracy is measured on meaningful targets.
data = []
for img, onehot in zip(test_imgs, test_color_labels):
    if onehot.sum() != 0:
        data.append((img, onehot))
print(len(data),len(test_imgs))
color_eval_loader = DataLoader(data, batch_size=128, shuffle=False)

In [None]:
# Classification accuracy of the color model on the labelled evaluation subset.
correct_num = 0
total_num = 0
color_model.eval()
# Inference only: disable autograd to avoid building graphs for every batch.
with torch.no_grad():
    for imgs,labels in color_eval_loader:
        imgs,labels = imgs.to(torch.float).to(device),labels.to(torch.float).to(device)
        y_pred_prob = color_model(imgs)
        # Compare the predicted class index with the index of the one-hot label.
        y_pred = torch.max(y_pred_prob,1)[1]
        labels = torch.max(labels,1)[1]
        correct_num += torch.sum(y_pred==labels).item()
        total_num += len(labels)

print("Total number of data for this evaluation is ",total_num)
print("Classification accuracy is ",correct_num/total_num)

In [None]:
# Display the first `show` test images and print the predicted color names.
show = 32
imshow(test_imgs[:show])

# Only the first test batch is needed for the preview.
for imgs, labels in color_test_loader:
    imgs = imgs.to(torch.float).to(device)
    labels = labels.to(torch.float).to(device)
    y_pred_prob = color_model(imgs)
    break

colors = torch.max(y_pred_prob, 1).indices[:show]
labels = torch.max(test_color_labels, 1).indices[:show]
color_names = list(COLOR.keys())
# Printed as rows of 8 predictions.
print(np.array([color_names[c] for c in colors]).reshape(-1, 8))
#print(utter_set[:show])

### Shape training

In [None]:
# Shape classifier setup: 4 output classes, Adam with default hyper-parameters.
epoch = 5  # number of passes over the training set
shape_model = CNN_encoder(4)
shape_model = shape_model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(shape_model.parameters())

In [None]:
# Train the shape classifier and track the mean per-batch loss of each epoch.
train_loss_list = []
test_loss_list = []
for i in range(epoch):
    print("##############################################")
    print("Epoch:{}/{}".format(i+1,epoch))
    train_loss = 0
    test_loss = 0

    # Switch the model being trained (not the smoke-test `model`) to train mode.
    shape_model.train()
    for imgs,labels in shape_train_loader:
        imgs,labels = imgs.to(torch.float).to(device),labels.to(torch.float).to(device)
        optimizer.zero_grad()
        y_pred = shape_model(imgs)
        # shape_model returns softmax probabilities, while CrossEntropyLoss
        # applies log-softmax itself. Passing log-probabilities makes that
        # internal log-softmax a no-op, giving the true cross-entropy instead
        # of a doubly-softmaxed one. The clamp guards against log(0).
        loss = criterion(torch.log(y_pred.clamp_min(1e-12)),labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    batch_train_loss = train_loss/len(shape_train_loader)

    shape_model.eval()
    # No gradients are needed for evaluation; skip building autograd graphs.
    with torch.no_grad():
        for imgs,labels in shape_test_loader:
            imgs,labels = imgs.to(torch.float).to(device),labels.to(torch.float).to(device)
            y_pred = shape_model(imgs)
            loss = criterion(torch.log(y_pred.clamp_min(1e-12)),labels)
            test_loss += loss.item()
    batch_test_loss = test_loss/len(shape_test_loader)

    print("Train Loss:{:.2E}, Test Loss:{:.2E}".format(batch_train_loss,batch_test_loss))
    train_loss_list.append(batch_train_loss)
    test_loss_list.append(batch_test_loss)

In [None]:
# Persist the trained shape classifier's parameters. This must be shape_model
# (not color_model) so the saved weights match the file name.
torch.save(shape_model.to(device).state_dict(),"model_params/shapeworld_cnn_shape_model.pth")

#### Accuracy test

In [None]:
# Keep only test samples whose shape label is not all-zero, i.e. samples that
# actually name a shape, so accuracy is measured on meaningful targets.
data = []
for img, onehot in zip(test_imgs, test_shape_labels):
    if onehot.sum() != 0:
        data.append((img, onehot))
print(len(data),len(test_imgs))
shape_eval_loader = DataLoader(data, batch_size=128, shuffle=False)

In [None]:
# Classification accuracy of the shape model on the labelled evaluation subset.
correct_num = 0
total_num = 0
shape_model.eval()
# Inference only: disable autograd to avoid building graphs for every batch.
with torch.no_grad():
    for imgs,labels in shape_eval_loader:
        imgs,labels = imgs.to(torch.float).to(device),labels.to(torch.float).to(device)
        # Evaluate the 4-class shape model; the 6-class color model does not
        # match these 4-column shape labels.
        y_pred_prob = shape_model(imgs)
        # Compare the predicted class index with the index of the one-hot label.
        y_pred = torch.max(y_pred_prob,1)[1]
        labels = torch.max(labels,1)[1]
        correct_num += torch.sum(y_pred==labels).item()
        total_num += len(labels)

print("Total number of data for this evaluation is ",total_num)
print("Classification accuracy is ",correct_num/total_num)

In [None]:
# Display the first `show` test images, then print predicted and labelled
# shape names in rows of 8.
show = 16
imshow(test_imgs[:show])

# Only the first test batch is needed for the preview.
for imgs, labels in shape_test_loader:
    imgs = imgs.to(torch.float).to(device)
    labels = labels.to(torch.float).to(device)
    y_pred_prob = shape_model(imgs)
    break

shapes = torch.max(y_pred_prob, 1).indices[:show]
labels = torch.max(test_shape_labels, 1).indices[:show]
# Skip the first key ("shape", the all-zero entry) so index k maps to the
# k-th one-hot position.
shape_names = list(SHAPE.keys())[1:]
print(np.array([shape_names[s] for s in shapes]).reshape(-1, 8))
print("\n")
print(np.array([shape_names[s] for s in labels]).reshape(-1, 8))