In [1]:
# Attach Google Drive to the Colab VM so the project files stored there are reachable.
from google.colab import drive
drive.mount('/content/drive')

# TODO: set this to the Drive folder holding the unzipped project,
# e.g. 'cse493g1/assignments/assignment2/'
FOLDERNAME = 'cse493g1/cse493g1project/'
assert FOLDERNAME is not None, "[!] Enter the foldername."

# Add the project folder to the interpreter's module search path so the
# project's own .py modules (e.g. model_trainer, Model, model_utils) import cleanly.
import sys
sys.path.append(f'/content/drive/My Drive/{FOLDERNAME}')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt

from model_trainer import Trainer
from Model import GraphCaptioningModel
from model_utils import decode_captions, create_minibatch, encode_captions

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import ast
import re

class GraphImageDataset(Dataset):
    """Dataset of (graph image path, caption) string pairs loaded from CSV files.

    Every CSV is read with ``pandas.read_csv`` and the frames are stacked
    row-wise with a fresh 0..N-1 index. Each row is expected to contain exactly
    two columns: the image path followed by the caption.

    Args:
        csv_files: iterable of CSV paths.
        transform: kept on the instance as ``self.transform`` but not applied
            by ``__getitem__``.
    """

    def __init__(self, csv_files, transform=None):
        frames = [pd.read_csv(csv_path) for csv_path in csv_files]
        self.data = pd.concat(frames, ignore_index=True)
        self.transform = transform

    def __len__(self):
        """Number of rows across all loaded CSVs."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return row ``idx`` as an ``(image_path, caption)`` tuple of strings."""
        image_path, caption = self.data.iloc[idx]
        return str(image_path), str(caption)

# Tensor type that data and models are cast to via `.type(dtype)`:
# CUDA float tensors when a GPU is available, CPU float tensors otherwise.
if torch.cuda.is_available():
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor

In [4]:
# All CSVs live under one datasets root. There are two shards (0 and 1); for each
# shard the "small" variants come first, then the "medium" variants, and each
# variant has one file per prefix kk/cr/gv/sp (in that order).
_DATASETS_ROOT = '/content/drive/My Drive/cse493g1/cse493g1project/datasets'
_GRAPH_KINDS = ('kk', 'cr', 'gv', 'sp')

csv_files2 = []
for shard in (0, 1):
    csv_files2.extend(
        f'{_DATASETS_ROOT}/datasets_small/data_{kind}{shard}.csv' for kind in _GRAPH_KINDS
    )
    csv_files2.extend(
        f'{_DATASETS_ROOT}/datasets_medium/data_{kind}{shard}_medium.csv' for kind in _GRAPH_KINDS
    )

dataset_mixed = GraphImageDataset(csv_files=csv_files2)

In [None]:
import PIL
import torchvision.transforms.functional as F
from torchvision.utils import make_grid
from torchvision.io import read_image
import torchvision.transforms as transform
from pathlib import Path

# Draw a random subset of the mixed dataset and load each referenced graph image.
NUM_GRAPH_SAMPLES = 2000
GRAPH_SAMPLE_SEED = 493
DATASETS_ROOT = '/content/drive/My Drive/cse493g1/cse493g1project/datasets'

raw_data_clr = {}
clr_len = len(dataset_mixed)

# Sample WITHOUT replacement. np.random.choice defaults to replace=True, which
# yields duplicate rows that would then appear in more than one of the
# train/val/test splits built later. A dedicated seeded generator makes the
# subset reproducible without touching the global numpy RNG state.
sample_rng = np.random.default_rng(GRAPH_SAMPLE_SEED)
sample_indices = sample_rng.choice(
    clr_len, size=min(NUM_GRAPH_SAMPLES, clr_len), replace=False
)

graph_list = []
caption_list = []
for i in sample_indices:
    graph_path, caption = dataset_mixed[i]
    # graph_path is appended to DATASETS_ROOT verbatim, so it is assumed to begin with '/'.
    # The `with` block closes the underlying image file once it has been decoded.
    with PIL.Image.open(DATASETS_ROOT + graph_path) as img:
        graph = F.pil_to_tensor(img.convert('RGB'))  # (C, H, W), uint8
    graph_list.append(graph.numpy())
    caption_list.append(caption)

# np.array stacks the images into one (N, C, H, W) array; this assumes every
# image has the same size -- TODO confirm for the medium datasets.
raw_data_clr['features'] = np.array(graph_list)
raw_data_clr['captions'] = np.array(caption_list)

In [None]:
import json

# Persist the sampled subset as JSON. json.dump cannot serialize numpy arrays
# (it raises "TypeError: Object of type ndarray is not JSON serializable"), so
# both arrays are converted to nested Python lists first. The features file is
# large: it stores one integer per pixel channel of every sampled image.
PROJECT_DIR = '/content/drive/My Drive/cse493g1/cse493g1project'

with open(f'{PROJECT_DIR}/color_features.json', 'w', encoding='utf8') as json_file:
    json.dump(raw_data_clr['features'].tolist(), json_file, ensure_ascii=True)

with open(f'{PROJECT_DIR}/color_captions.json', 'w', encoding='utf8') as json_file:
    json.dump(raw_data_clr['captions'].tolist(), json_file, ensure_ascii=True)

In [None]:
# Sanity-check the sampled subset: overall shapes, one example image summary,
# and one example caption.
print("MIXED COLOR DATA")

print(raw_data_clr['features'].shape)
first_graph = raw_data_clr['features'][0]
# Summarize the example image instead of printing its pixel values; the value
# range matters for display later (pil_to_tensor produces 0-255 uint8 data).
print(first_graph.dtype, first_graph.min(), first_graph.max())
print(first_graph.shape)

print(raw_data_clr['captions'].shape)
print(raw_data_clr['captions'][0])
# A numpy string scalar's .shape is always (); len() gives the caption length.
print(len(raw_data_clr['captions'][0]))

In [None]:
## THIS BLOCK WAS JUST FOR TESTING, THE NEXT BLOCK IS THE ACTUAL TRAINING CODE ##
# Smoke test: fit a model on the whole sampled subset with batch_size=1.

from model_trainer import Trainer
from Model import GraphCaptioningModel
from model_utils import decode_captions, create_minibatch, encode_captions

data = {}

# Vocabulary: 3 special tokens, the integers 0-99 as strings, then single
# punctuation/space tokens. word_to_idx is the inverse mapping.
punc = ['{', '}', '[', ']', '(', ')', ':', ',', ' ']
data['idx_to_word'] = ['<NULL>', '<START>', '<END>'] + [str(i) for i in range(100)] + punc
data['word_to_idx'] = {word: idx for idx, word in enumerate(data['idx_to_word'])}

# Use the subset sampled earlier (raw_data_clr). Features are passed through as
# stacked, one entry per sample; wrapping them in an extra list would add a
# spurious leading dimension of size 1 and mismatch the caption count.
data['train_captions'] = torch.tensor(
    encode_captions(raw_data_clr['captions'], data['word_to_idx'])
).type(dtype)
data['train_features'] = torch.tensor(raw_data_clr['features']).type(dtype)
print(data['train_features'].shape)
print(data['train_captions'].shape)

transformer = GraphCaptioningModel(
          word_to_idx=data['word_to_idx'],
          wordvec_dim=256,
          max_length=2000
        ).type(dtype)


transformer_solver = Trainer(transformer, data, idx_to_word=data['idx_to_word'],
           num_epochs=10,
           batch_size=1,
           learning_rate=0.001,
           verbose=True, print_every=10,
         )

transformer_solver.train()

# Plot the training losses.
plt.plot(transformer_solver.loss_history)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('Training loss history')
plt.show()

In [None]:
from model_trainer import Trainer
from Model import GraphCaptioningModel
from model_utils import decode_captions, create_minibatch, encode_captions

torch.manual_seed(493)
np.random.seed(493)


data_clr = {}

# Vocabulary: 3 special tokens, the integers 0-99 as strings, then single
# punctuation/space tokens. word_to_idx is the inverse mapping.
punc = ['{', '}', '[', ']', '(', ')', ':', ',', ' ']
data_clr['idx_to_word'] = ['<NULL>', '<START>', '<END>'] + [str(i) for i in range(100)] + punc
data_clr['word_to_idx'] = {word: idx for idx, word in enumerate(data_clr['idx_to_word'])}

# 80/10/10 train/val/test split; any remainder from the integer division goes to
# test. The split size must come from the number of samples, not from
# len(raw_data_clr), which counts the dict's keys (2) and gives tenth == 0.
num_samples = len(raw_data_clr['captions'])
assert len(raw_data_clr['features']) == num_samples, "features/captions length mismatch"
tenth = num_samples // 10
assert tenth > 0, "need at least 10 samples for an 80/10/10 split"
train_end, val_end = tenth * 8, tenth * 9

encoded_captions = encode_captions(raw_data_clr['captions'], data_clr['word_to_idx'])

data_clr['train_captions'] = torch.tensor(encoded_captions[:train_end]).type(dtype)
data_clr['train_features'] = torch.tensor(raw_data_clr['features'][:train_end]).type(dtype)

data_clr['val_captions'] = torch.tensor(encoded_captions[train_end:val_end]).type(dtype)
data_clr['val_features'] = torch.tensor(raw_data_clr['features'][train_end:val_end]).type(dtype)

data_clr['test_captions'] = torch.tensor(encoded_captions[val_end:]).type(dtype)
data_clr['test_features'] = torch.tensor(raw_data_clr['features'][val_end:]).type(dtype)


# Build the model and trainer from this cell's own vocabulary (data_clr) so the
# cell does not depend on the smoke-test cell's `data` dict having been run.
graph_model_clr = GraphCaptioningModel(
          word_to_idx=data_clr['word_to_idx'],
          wordvec_dim=256,
          max_length=1600
        ).type(dtype)


model_solver_clr = Trainer(graph_model_clr, data_clr, idx_to_word=data_clr['idx_to_word'],
           num_epochs=10,
           batch_size=10,
           learning_rate=0.001,
           verbose=True, print_every=10,
         )

model_solver_clr.train()

# Plot the training losses.
plt.plot(model_solver_clr.loss_history)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('Training loss history')
plt.show()

In [None]:
# Caption a couple of train and val graphs with the trained model (graph_model_clr)
# and display each image with its generated and ground-truth captions.


def _to_displayable(image):
    """Convert one image tensor/array into an (H, W, C) float array in [0, 1] for plt.imshow.

    Assumes channel-first (C, H, W) layout with 0-255 pixel values, matching what
    F.pil_to_tensor produced in the sampling cell -- TODO confirm create_minibatch
    preserves that layout.
    """
    image = torch.as_tensor(image).detach().cpu().float()
    if image.dim() == 3 and image.shape[0] == 3:
        image = image.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
    # imshow clips float data to [0, 1], so rescale from the 0-255 pixel range.
    return (image / 255.0).clamp(0.0, 1.0).numpy()


for split in ['train', 'val']:
    # NOTE(review): assumes create_minibatch returns (captions, features) in that order.
    gt_captions, features = create_minibatch(data_clr, split=split, batch_size=2)
    gt_captions = decode_captions(gt_captions, data_clr['idx_to_word'])

    sample_captions = graph_model_clr.sample(features, max_length=1600)
    sample_captions = decode_captions(sample_captions, data_clr['idx_to_word'])

    for gt_caption, sample_caption, feature in zip(gt_captions, sample_captions, features):
        plt.imshow(_to_displayable(feature))
        plt.title('%s\n%s\nGT:%s' % (split, sample_caption, gt_caption))
        plt.axis('off')
        plt.show()