# Train-Eval
---

## Import Libraries

In [1]:
import os
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim

from torchtext.data import BucketIterator

In [2]:
# Put the parent directory on the import path so the project-local `meta_infomax`
# package can be imported below. The path is relative, so this presumably assumes
# the kernel's working directory is a direct child of the repo root — TODO confirm.
sys.path.append("../")
from meta_infomax.datasets.fudan_reviews import prepare_data, get_data

## Global Constants

In [55]:
# Names of the review domains; the models below are built one per entry.
DATASETS = [
    'apparel', 'baby', 'books', 'camera_photo', 'electronics',
    'health_personal_care', 'imdb', 'kitchen_housewares', 'magazines',
    'music', 'software', 'sports_outdoors', 'toys_games', 'video',
]
# Derived from the list above so the two can never disagree (currently 14).
NUM_TASKS = len(DATASETS)

# Training batch size; the dev/test iterators use twice this value.
BSIZE = 16
# Number of passes over the training data (not used by any cell shown here).
EPOCHS = 1

# Hidden size of each LSTM direction in the encoders.
ENCODER_DIM = 100
# Hidden width of the classifier head.
CLASSIFIER_DIM = 100

# Load Data

In [56]:
from torchtext.vocab import GloVe

In [5]:
# `prepare_data()` is intentionally left disabled; presumably it is a one-time
# preprocessing step that has to have been run before `get_data()` — TODO confirm.
# prepare_data()
# `get_data()` returns four values: the train/dev/test sets and the vocabulary object.
train_set, dev_set, test_set, vocab = get_data()

In [6]:
# One iterator per split. Training batches hold BSIZE examples; dev and test
# batches are twice as large. `sort_key` orders examples by text length so that
# similarly sized examples are bucketed together; the examples inside a batch are
# not re-sorted (sort_within_batch=False).
train_iter, dev_iter, test_iter = BucketIterator.splits((train_set, dev_set, test_set),
                                                        batch_sizes=(BSIZE, BSIZE*2, BSIZE*2),
                                                        sort_within_batch=False,
                                                        sort_key=lambda x: len(x.text))

In [7]:
# Pull a single batch from the training iterator and display its summary.
batch = next(iter(train_iter))
batch


[torchtext.data.batch.Batch of size 16]
	[.label]:[torch.LongTensor of size 16]
	[.text]:('[torch.LongTensor of size 16x707]', '[torch.LongTensor of size 16]')
	[.task]:[torch.LongTensor of size 16]

In [8]:
# `batch.text` is a pair; its first element is the token-id tensor.
# Show its shape next to the shapes of the label and task-id tensors.
batch.text[0].shape, batch.label.shape, batch.task.shape

(torch.Size([16, 707]), torch.Size([16]), torch.Size([16]))

In [14]:
# Vocabulary index of the padding token; MultiTaskInfoMax passes this value
# to its embedding layer as `padding_idx`.
vocab.stoi["<pad>"]

1

# Baseline Model

In [45]:
class Encoder(nn.Module):
    """Bidirectional LSTM encoder over a batch of embedded sequences.

    Args:
        emb_dim: size of each input embedding vector.
        hidden_dim: hidden size of each LSTM direction.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, emb_dim, hidden_dim, num_layers):
        super().__init__()
        self.lstm = nn.LSTM(emb_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True)

    def forward(self, x):
        """Encode ``x`` and return the per-time-step LSTM outputs.

        Args:
            x: float tensor of shape (batch, seq_len, emb_dim).

        Returns:
            Tensor of shape (batch, seq_len, 2 * hidden_dim); the last axis holds
            the forward and backward direction outputs side by side.
        """
        # The previous version read self.h0 / self.c0, which were never created
        # and raised AttributeError. When no initial state is passed, nn.LSTM
        # starts from zeros on the input's device, which covers that intent.
        out, _ = self.lstm(x)
        return out

In [46]:
class Classifier(nn.Module):
    """Feed-forward head: Linear -> ReLU -> Linear.

    Args:
        in_dim: size of the input feature vector.
        hidden_dim: width of the hidden layer.
        out_dim: number of output units.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        # Modules are listed in the order they are applied.
        stack = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        logits = self.layers(x)
        return logits

In [47]:
class MultiTaskInfoMax(nn.Module):
    """Per-task model: frozen pretrained embeddings, a shared encoder, a private encoder, and a classifier.

    Args:
        shared_encoder: encoder module supplied by the caller; passing the same
            instance to several models makes them share its weights.
        embeddings: 2-D tensor of pretrained vectors; its last dimension is the
            embedding size fed to the private encoder.
        vocab: object exposing ``stoi``; ``stoi["<pad>"]`` is used as the padding index.
        encoder_dim: hidden size per direction of the private encoder. The
            classifier input is ``encoder_dim * 4``, so ``shared_encoder`` must
            also emit ``2 * encoder_dim`` features per position.
        encoder_layers: number of layers in the private encoder.
        classifier_dim: hidden width of the classifier.
        out_dim: number of classifier outputs.
    """

    def __init__(self, shared_encoder, embeddings, vocab, encoder_dim, encoder_layers, classifier_dim, out_dim):
        super().__init__()
        self.emb = nn.Embedding.from_pretrained(embeddings, freeze=True, padding_idx=vocab.stoi["<pad>"])
        self.shared_encoder = shared_encoder
        self.private_encoder = Encoder(embeddings.shape[-1], encoder_dim, encoder_layers)
        self.classifier = Classifier(encoder_dim*4, classifier_dim, out_dim)

    def forward(self, sentences, lengths):
        """Classify a batch of token-id sequences.

        Args:
            sentences: integer tensor of token ids, shape (batch, seq_len).
            lengths: integer tensor of shape (batch,) with the number of real
                (non-padding) tokens per row, or ``None`` to average over every position.

        Returns:
            Tuple ``(out, shared_out, private_out)``: ``out`` has shape
            (batch, out_dim); the other two are the raw per-position outputs of
            the shared and private encoders.
        """
        sent_embed = self.emb(sentences)
        shared_out = self.shared_encoder(sent_embed)
        private_out = self.private_encoder(sent_embed)

        # Join the two encodings along the FEATURE axis so the width matches the
        # classifier's encoder_dim * 4 input. The previous dim=1 stacked them
        # along the sequence axis and left the feature width at 2 * encoder_dim.
        h = torch.cat((shared_out, private_out), dim=-1)

        # Collapse the sequence axis into one vector per sentence by averaging
        # over real tokens only, so padding does not dilute the representation.
        # NOTE(review): mean pooling is a design choice — confirm it is the intended summary.
        if lengths is None:
            pooled = h.mean(dim=1)
        else:
            lengths = lengths.to(h.device)
            positions = torch.arange(h.size(1), device=h.device)
            mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(-1).to(h.dtype)
            # clamp guards against division by zero for an empty sequence
            denom = lengths.clamp(min=1).unsqueeze(1).to(h.dtype)
            pooled = (h * mask).sum(dim=1) / denom

        out = self.classifier(pooled)
        return out, shared_out, private_out

# Train

## Overfit Batch

In [48]:
# Shape of the pretrained vector matrix attached to the vocabulary; its second
# dimension is used below as the encoder's input (embedding) size.
vocab.vectors.shape

torch.Size([65551, 300])

In [49]:
# A single-layer encoder created once here and handed to every task model below.
shared_encoder = Encoder(
    emb_dim=vocab.vectors.shape[1],
    hidden_dim=ENCODER_DIM,
    num_layers=1,
)
shared_encoder

Encoder(
  (lstm): LSTM(300, 100, batch_first=True, bidirectional=True)
)

In [58]:
# One model per dataset. Every model receives the same `shared_encoder` instance,
# while each constructs its own private encoder and classifier.
multitask_models = [
    MultiTaskInfoMax(
        shared_encoder=shared_encoder,
        embeddings=vocab.vectors,
        vocab=vocab,
        encoder_dim=ENCODER_DIM,
        encoder_layers=1,
        classifier_dim=CLASSIFIER_DIM,
        out_dim=2,
    )
    for _ in DATASETS
]

In [60]:
# Display the module structure of the model at index 1.
multitask_models[1]

MultiTaskInfoMax(
  (emb): Embedding(65551, 300, padding_idx=1)
  (shared_encoder): Encoder(
    (lstm): LSTM(300, 100, batch_first=True, bidirectional=True)
  )
  (private_encoder): Encoder(
    (lstm): LSTM(300, 100, batch_first=True, bidirectional=True)
  )
  (classifier): Classifier(
    (layers): Sequential(
      (0): Linear(in_features=400, out_features=100, bias=True)
      (1): ReLU()
      (2): Linear(in_features=100, out_features=2, bias=True)
    )
  )
)

In [None]:
# NOTE(review): `multitask_models` is a plain list and `batch` is a torchtext
# Batch, so this indexing raises TypeError (list indices must be integers or
# slices). Presumably the intent was to pick a model by task id and call it on
# the batch, e.g. `multitask_models[task_id](*batch.text)` — confirm and fix.
multitask_models[batch]