In [None]:
from numpy.core.fromnumeric import mean
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
from torchtext import data
import torch.optim as optim
import argparse
import os
import pandas as pd
import matplotlib.pyplot as plt

### 3.3 Processing of the data ###
# 3.3.1
# Load the pretrained GloVe "6B" word vectors with 100 dimensions per word.
# The first run downloads an ~862MB archive to .vector_cache/glove.6B.zip;
# later runs reuse the cached copy.
glove = torchtext.vocab.GloVe(name="6B",dim=100) # embedding size = 100

# TextDataset is described in Section 3.3 of Assignment 2.

class TextDataset(torch.utils.data.Dataset):
    """Dataset of (token-index tensor, label) pairs read from ``<split>.tsv``.

    The TSV file is read with a header row; the first column holds the
    whitespace-separated sentence and the second column the numeric label.

    Args:
        vocab: Object exposing ``vectors`` (its length is the vocabulary
            size) and ``stoi`` (a word -> index mapping).
        split: File stem of the split to load, e.g. ``"train"``.
        data_path: Directory containing the TSV files. Defaults to ``"data"``,
            the location that was previously hard-coded.

    Attributes set:
        X: list of 1-D ``torch.long`` tensors of word indices, one per
            sentence, unpadded (lengths vary).
        Y: 1-D float tensor with one label per sentence.
    """

    def __init__(self, vocab, split="train", data_path="data"):
        df = pd.read_csv(os.path.join(data_path, f"{split}.tsv"), sep="\t")

        # Use the last word in the vocab as the "out-of-vocabulary" token.
        oov_index = len(vocab.vectors) - 1
        lookup = vocab.stoi

        # Explicit dtype keeps every item an integer tensor, even when a
        # sentence has no tokens (torch.tensor([]) would default to float32).
        self.X = [
            torch.tensor(
                [lookup.get(word, oov_index) for word in sentence.split()],
                dtype=torch.long,
            )
            for sentence in df.iloc[:, 0]
        ]
        self.Y = torch.tensor([float(label) for label in df.iloc[:, 1]])

    def __len__(self):
        """Return the number of sentences in the split."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(indices, label)`` pair stored at position ``idx``."""
        return self.X[idx], self.Y[idx]

# The baseline model comes first; my_collate_function, defined right after it,
# prepares the batches and pads each batch with zeroes.

class baseline(torch.nn.Module):
    """Averaged-embedding classifier producing one logit per sentence.

    Each token index is looked up in an embedding table initialised from
    ``vocab.vectors``, the embeddings are averaged along dimension 1, and a
    single linear layer maps the average to one logit.

    Args:
        vocab: Object whose ``vectors`` attribute is a 2-D float tensor of
            shape (vocabulary size, embedding width).
        device: Device the embedding and linear layers are moved to.
    """

    def __init__(self, vocab, device):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(vocab.vectors).to(device)
        # Size the linear layer from the vectors themselves rather than
        # assuming 100-d GloVe, so any embedding width works.
        embedding_dim = vocab.vectors.shape[1]
        self.linear = nn.Linear(embedding_dim, 1).to(device)

    def forward(self, x):
        """Return one logit per row of ``x``.

        ``x`` is indexed as (batch, sequence): the mean is taken over
        dimension 1 and the trailing size-1 dimension is squeezed away.
        """
        averaged = torch.mean(self.embedding(x), dim=1)
        return self.linear(averaged).squeeze(1)

def my_collate_function(batch, device):
    """Pad a list of dataset items into one batch and move it to ``device``.

    ``collate_fn`` only receives the batch, so callers pass ``device`` by
    wrapping this function (e.g. in a lambda).

    Args:
        batch: Sequence of ``(indices, label)`` pairs, where ``indices`` is a
            1-D tensor of word indices and ``label`` is a scalar.
        device: Target device for both returned tensors.

    Returns:
        A tuple ``(x, y)``: ``x`` is an int32 tensor of shape
        (len(batch), longest sentence in the batch), right-padded with 0;
        ``y`` is a 1-D tensor of the labels in batch order.
    """
    # Keep the indices integral throughout. Concatenating them with float
    # zeros would promote the batch to float32 and silently round indices
    # above 2**24.
    sentences = [torch.as_tensor(x, dtype=torch.long) for x, _ in batch]
    labels = [y for _, y in batch]

    # NOTE(review): padding uses index 0, which is presumably also a real
    # vocabulary entry -- confirm whether downstream code should mask it.
    padded = torch.nn.utils.rnn.pad_sequence(
        sentences, batch_first=True, padding_value=0
    )
    return padded.int().to(device), torch.tensor(labels).to(device)


def _count_errors(logits, labels):
    """Return how many thresholded predictions disagree with ``labels``."""
    predictions = torch.round(torch.sigmoid(logits.detach()))
    return int((predictions != labels).sum())


def _plot_curves(epochs_axis, train_values, val_values, title, ylabel,
                 train_label, val_label):
    """Draw one train-vs-validation curve pair on a new figure."""
    fig, ax = plt.subplots()
    ax.plot(epochs_axis, train_values, label=train_label)
    ax.plot(epochs_axis, val_values, label=val_label)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.legend()


def train_baseline(epochs, learning_rate, train_dataloader, val_dataloader, vocab, device):
    """Train a ``baseline`` model and plot its loss and error curves.

    Args:
        epochs: Number of passes over ``train_dataloader``.
        learning_rate: Learning rate given to the Adam optimiser.
        train_dataloader: Iterable of ``(sentences, labels)`` training batches.
        val_dataloader: Iterable of ``(sentences, labels)`` validation batches.
        vocab: Vocabulary object passed through to ``baseline``.
        device: Device the model and batches are placed on.

    Returns:
        The trained ``baseline`` network.

    Per epoch, the loss is the mean of the per-batch BCE-with-logits losses
    and the error is (mispredicted samples) / (samples seen). Two matplotlib
    figures are created at the end: one for losses, one for errors.
    """
    net = baseline(vocab, device)
    opt = optim.Adam(net.parameters(), lr=learning_rate)
    criterion = nn.BCEWithLogitsLoss()

    train_losses, val_losses = [], []
    train_errors, val_errors = [], []
    epochslist = []

    for epoch in range(epochs):
        print("Epoch: "+str(epoch))

        # ---- training pass ----
        net.train()
        batch_losses = []
        wrong = 0
        seen = 0
        for sentences, labels in train_dataloader:
            sentences = sentences.to(device)
            labels = labels.to(device)
            opt.zero_grad()
            logits = net(sentences)
            loss = criterion(logits, labels)
            loss.backward()
            opt.step()
            batch_losses.append(float(loss.item()))
            wrong += _count_errors(logits, labels)
            seen += len(labels)
        train_errors.append(wrong / seen)
        train_losses.append(sum(batch_losses) / len(batch_losses))

        # ---- validation pass ----
        # No gradients are needed here; skipping autograd avoids building a
        # graph for every validation batch.
        net.eval()
        batch_losses = []
        wrong = 0
        seen = 0
        with torch.no_grad():
            for sentences, labels in val_dataloader:
                sentences = sentences.to(device)
                labels = labels.to(device)
                logits = net(sentences)
                batch_losses.append(float(criterion(logits, labels).item()))
                wrong += _count_errors(logits, labels)
                seen += len(labels)
        val_errors.append(wrong / seen)
        val_losses.append(sum(batch_losses) / len(batch_losses))

        epochslist.append(epoch)

    _plot_curves(epochslist, train_losses, val_losses,
                 "Training and Validation Loss", "Loss",
                 'Training Loss', 'Validation Loss')
    _plot_curves(epochslist, train_errors, val_errors,
                 "Training and Validation Errors", "Error",
                 'Training Errors', 'Validation Errors')

    return net
def evaluate(net, dataloader):
    """Compute, print and return the classification accuracy of ``net``.

    Args:
        net: Callable mapping a batch of sentences to one logit per sample.
        dataloader: Iterable of ``(sentences, labels)`` batches.

    Returns:
        Fraction of samples whose thresholded prediction
        (``round(sigmoid(logit))``) equals the label.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    n_correct = 0
    n_total = 0
    with torch.no_grad():  # inference only; no autograd graph needed
        for sentences, labels in dataloader:
            predictions = torch.round(torch.sigmoid(net(sentences)))
            # Count matches (not mismatches) so the result really is an
            # accuracy, and count samples per batch so a smaller final batch
            # does not skew the denominator.
            n_correct += int((predictions == labels).sum())
            n_total += labels.numel()
    if n_total == 0:
        raise ValueError("evaluate() needs a dataloader with at least one sample")
    accuracy = n_correct / n_total
    print("Final accuracy: "+str(accuracy))
    return accuracy

# Script-level setup: batch size, RNG seed, device, datasets and dataloaders.
batch_size = 50
# Fix the seed so runs are repeatable.
torch.manual_seed(2)
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print ("Using device:", device)

# 3.3.2 -- one dataset per split, all built on the same GloVe vocabulary.
train_dataset, val_dataset, test_dataset, overfit_dataset = (
    TextDataset(glove, split_name)
    for split_name in ("train", "val", "test", "overfit")
)


# 3.3.3 -- every loader shares the same batch size, shuffling and collate step.
def _make_loader(dataset):
    """Wrap ``dataset`` in a shuffled DataLoader that pads via my_collate_function."""
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        # collate_fn takes a single argument, so bind ``device`` with a lambda.
        collate_fn=lambda batch: my_collate_function(batch, device),
    )


train_dataloader = _make_loader(train_dataset)
val_dataloader = _make_loader(val_dataset)
test_dataloader = _make_loader(test_dataset)
overfit_dataloader = _make_loader(overfit_dataset)

#Run these functions to test my code
def train_overfit(): #4.4
    """Train the baseline on the overfit split (50 epochs, lr 0.001) and return it.

    Validation curves are computed against ``val_dataloader``.
    """
    return train_baseline(
        epochs=50,
        learning_rate=0.001,
        train_dataloader=overfit_dataloader,
        val_dataloader=val_dataloader,
        vocab=glove,
        device=device,
    )
def train_full(): #4.5 and 4.7
    """Train on the full training split, evaluate on the test split, and save.

    Runs 50 epochs at learning rate 0.001, calls ``evaluate`` on
    ``test_dataloader``, writes the model's state dict to
    'model baseline.pt', and returns the trained network.
    """
    model = train_baseline(
        epochs=50,
        learning_rate=0.001,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        vocab=glove,
        device=device,
    )
    evaluate(model, test_dataloader)
    state = model.state_dict()
    torch.save(state, 'model baseline.pt')
    return model
def get_closest_words(net=None, n=20): #4.6
    """Print the ``n`` vocabulary words most cosine-similar to the linear weights.

    Args:
        net: Model whose ``linear.weight`` parameter is used as the query
            vector. When omitted, a module-level variable named ``net`` is
            used, which is how the function was previously called.
        n: Number of words to print (default 20).

    Raises:
        NameError: If ``net`` is not given and no module-level ``net`` exists.
        KeyError: If the model has no parameter named ``linear.weight``.

    Each printed line is the word followed by its similarity score.
    """
    if net is None:
        net = globals().get("net")
    if net is None:
        raise NameError(
            "get_closest_words() needs a trained model: pass net=... or "
            "define a module-level 'net'"
        )

    weights = dict(net.named_parameters())["linear.weight"].detach()

    # Cosine similarity of every word vector against the weight vector;
    # the (1, D) weights broadcast against the (V, D) vocabulary matrix.
    similarities = torch.cosine_similarity(glove.vectors.to(device), weights.to(device))

    # The query is the weight vector, not a vocabulary word, so the best match
    # is a genuine result and is kept (nothing is "the target itself").
    top = torch.topk(similarities, k=min(n, similarities.numel()))
    for score, idx in zip(top.values.tolist(), top.indices.tolist()):
        print(glove.itos[idx], "\t%5.2f" % score)
    return