In [1]:
# This cell makes sure modules are auto-loaded when you change external python files
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ["WANDB_DISABLED"] = "true"
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0" # limiting to one GPU

In [3]:
from datasets import load_dataset
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn as nn
import evaluate
import gensim
import transformers
import nltk
import re
import random
import numpy as np
import torch

## Enter your SCIPER here: ##
SCIPER = "369141"

try:
    assert re.match("\d{6}", SCIPER)[0] == SCIPER, "Invalid SCIPER given. please enter your correct 6-digit SCIPER number above!"
except:
    print("Invalid SCIPER given. please enter your correct 6-digit SCIPER number above!")

student_seed = int(SCIPER)


"""Set seed for reproducibility."""
random.seed(student_seed)
np.random.seed(student_seed)
torch.manual_seed(student_seed)
torch.cuda.manual_seed_all(student_seed)
torch.mps.manual_seed(student_seed)

In [4]:
# We will use NLTK to tokenize the text
from datasets import load_dataset
import nltk
nltk.download('punkt')
scomp = load_dataset("embedding-data/sentence-compression")
scomp["train"][0]
scomp_small_train = scomp["train"].filter(lambda sample: 10 < len(nltk.word_tokenize(sample["set"][0])) < 100)

scomp_small = scomp
scomp_small["train"] = scomp_small_train
scomp_train, scomp_val, scomp_test = scomp_small["train"][:10000]['set'], scomp_small["train"][10000:11000]['set'], scomp_small["train"][11000:12000]['set']

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/saidgurbuz/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [5]:
from utils import load_binary

wikitext_vocab = load_binary("wikitext_vocab.pkl")
from data import CustomTokenizer

custom_tokenizer = CustomTokenizer(vocab=wikitext_vocab)
from data import SCompDataset

MAX_SEQ_LENGTH = 128
scomp_train_ds = SCompDataset(scomp_train, custom_tokenizer, MAX_SEQ_LENGTH)
scomp_val_ds = SCompDataset(scomp_val, custom_tokenizer, MAX_SEQ_LENGTH)
scomp_test_ds = SCompDataset(scomp_test, custom_tokenizer, MAX_SEQ_LENGTH)
from torch.utils.data import DataLoader

# feel free to change batch size according to your GPU memory
scomp_train_dataloader = DataLoader(scomp_train_ds, batch_size=32, shuffle=True)
scomp_val_dataloader = DataLoader(scomp_val_ds, batch_size=32, shuffle=True)
scomp_test_dataloader = DataLoader(scomp_test_ds, batch_size=32, shuffle=True)
from modeling import VanillaLSTM

vocab_size = len(wikitext_vocab)
embedding_dim = 100
hidden_dim = 100
num_layers = 2
dropout_rate = 0.15

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
device = "cpu"

pretrained_encoder = VanillaLSTM(vocab_size, embedding_dim, hidden_dim,
                                  num_layers, dropout_rate=dropout_rate).to(device)

# TODO: Load the pretrained model from the file
pretrained_encoder.load_state_dict(torch.load('models/lstm_pretrained.pt'))
from modeling import EncoderDecoder

lr = 1e-3
dropout_rate = 0.15
bos_token_id = custom_tokenizer.bos_token_id
encoder_decoder = EncoderDecoder(hidden_dim, vocab_size, vocab_size, bos_token_id=bos_token_id, dropout_rate=dropout_rate, pretrained_encoder=pretrained_encoder).to(device)
optimizer = optim.Adam(encoder_decoder.parameters(), lr=lr)
criterion = nn.NLLLoss(ignore_index=custom_tokenizer.pad_token_id)
num_params = sum(p.numel() for p in encoder_decoder.parameters() if p.requires_grad)
print(f'The model has {num_params:,} trainable parameters')

The model has 25,470,167 trainable parameters


In [6]:
import torch.nn.functional as F

In [7]:
criterion(F.log_softmax(torch.rand(32, 1000), dim=-1), torch.randint(0, 1000, (32,)))

tensor(6.9271)

In [13]:
for i in scomp_train_dataloader:
    print(i["input_ids"].shape)

torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size([32, 130])
torch.Size

In [8]:
from modeling import seq2seq_train

# ETS: ~30 mins to run with a batch size of 32 for 20 epochs
seq2seq_train(model=encoder_decoder,
              train_loader=scomp_train_dataloader,
              eval_loader=scomp_val_dataloader,
              optimizer=optimizer,
              criterion=criterion,
              device=device,
              tensorboard_path="./tensorboard/encoder_decoder")
# save the model
torch.save(encoder_decoder.state_dict(), "models/encoder_decoder.pt")

  0%|          | 0/20 [00:00<?, ?it/s]
  0%|          | 0/313 [00:00<?, ?it/s]

[A

loss: 10.827292442321777




loss: 10.665144920349121




loss: 10.530056953430176




loss: 10.391977310180664




loss: 10.207572937011719




loss: 9.967386245727539




loss: 9.90947151184082




loss: 9.544557571411133




loss: 9.427667617797852




loss: 9.171013832092285




loss: 8.851107597351074




loss: 8.31529712677002




loss: 7.917985916137695




loss: 7.563299655914307




loss: 6.685769557952881




loss: 6.076077938079834




loss: 6.503880977630615




loss: 6.405498504638672




loss: 5.745875358581543




loss: 6.5141777992248535




loss: 5.863117218017578




loss: 6.170448303222656




loss: 5.6204833984375




loss: 5.541178226470947




loss: 5.327462196350098




loss: 5.168258190155029




loss: 5.335148334503174




loss: 5.23144006729126




loss: 5.221790790557861




loss: 5.093507766723633




loss: 5.2406816482543945




loss: 5.399551868438721




loss: 5.576704978942871




loss: 5.26125955581665




loss: 5.366063117980957




loss: 5.33480167388916




loss: 5.3037638664245605




loss: 5.276780605316162




loss: 5.175810813903809




loss: 5.330368995666504




loss: 5.483868598937988




loss: 5.342341423034668




loss: 5.073168754577637




loss: 5.1142964363098145




loss: 4.959949493408203




loss: 5.170703411102295




loss: 5.246708869934082




loss: 5.391226291656494




loss: 4.677867889404297




loss: 5.472489833831787




loss: 5.466957092285156




loss: 4.984728813171387




loss: 4.943873405456543




loss: 4.555212020874023




loss: 4.767164707183838




loss: 5.17538595199585




loss: 5.096730709075928




loss: 4.69058895111084




loss: 4.718313694000244




loss: 5.2108869552612305




loss: 4.919214248657227




loss: 4.7175211906433105




loss: 4.610531806945801




loss: 5.047734260559082




loss: 4.812112808227539




loss: 5.009567737579346




loss: 5.182263374328613




loss: 5.368485927581787




loss: 4.924344539642334




loss: 4.757671356201172




loss: 5.276025772094727




loss: 5.250858783721924




loss: 4.945662975311279




loss: 5.077049732208252




loss: 4.6234517097473145




loss: 4.900463104248047




loss: 4.301915168762207




loss: 4.496082305908203




loss: 4.596930027008057




loss: 4.490263938903809




loss: 5.049514293670654




loss: 4.5372185707092285




loss: 5.268503189086914




loss: 5.0294880867004395




loss: 5.128412246704102




loss: 4.1343512535095215




loss: 4.296866416931152




loss: 4.8502068519592285




loss: 4.6646904945373535




loss: 4.9687418937683105




loss: 4.196323871612549




loss: 4.897042274475098




loss: 4.7117919921875




loss: 4.862868309020996




loss: 5.157594203948975




loss: 4.5524492263793945




loss: 4.87459659576416




loss: 4.730757236480713




loss: 4.76134729385376




loss: 4.61478328704834




[1,   100] loss: 5.732
loss: 4.418924331665039




loss: 4.841302394866943




loss: 4.892247676849365




loss: 4.72377872467041




loss: 4.872066020965576




loss: 4.776157855987549




loss: 4.801754951477051




loss: 4.38850736618042




loss: 5.050229072570801




loss: 5.328547954559326




loss: 4.8201904296875




loss: 4.575374603271484




loss: 5.13273286819458




loss: 4.671786785125732




loss: 4.676393508911133




loss: 4.572262763977051




loss: 4.717039585113525




loss: 4.6082844734191895




loss: 4.801655292510986




loss: 4.4975128173828125




loss: 4.909696578979492




loss: 4.999222278594971




loss: 5.301118850708008




loss: 4.573108673095703




loss: 4.460090160369873




loss: 4.882234573364258




loss: 4.262575626373291




loss: 4.563714027404785




loss: 4.628406047821045




loss: 5.073214054107666




loss: 4.8282575607299805




loss: 4.819249629974365




loss: 4.874283313751221




loss: 4.497954368591309




loss: 4.811588287353516




loss: 4.606951713562012




loss: 4.587220668792725




loss: 5.293520450592041




loss: 4.801088333129883




loss: 4.9080657958984375




loss: 5.060977458953857




loss: 4.651212692260742




loss: 4.5231451988220215




loss: 4.593672275543213




loss: 4.744082927703857




loss: 4.663118362426758




loss: 4.684500694274902




loss: 4.370093822479248




loss: 4.669507026672363




loss: 4.0458292961120605




loss: 4.4615020751953125




loss: 4.559922218322754




loss: 4.822371006011963




loss: 4.6826372146606445




loss: 4.7652201652526855




loss: 4.243335247039795




loss: 4.912631034851074




loss: 4.909542560577393




loss: 4.479888916015625




loss: 4.587503433227539




loss: 4.873819828033447




loss: 4.848729610443115




loss: 4.317362308502197




loss: 4.495446681976318




loss: 4.5450873374938965




loss: 4.534912586212158




loss: 4.783287048339844




loss: 4.882073879241943




loss: 4.055534362792969




loss: 4.369653701782227




loss: 4.5197343826293945




loss: 4.692651748657227




loss: 4.991121768951416




loss: 4.735671520233154




loss: 5.066282272338867




loss: 5.027756690979004




loss: 4.421435832977295




loss: 4.867494106292725




loss: 5.047513484954834




loss: 4.47902774810791




loss: 4.57152795791626




loss: 4.539588451385498




loss: 4.57363224029541




loss: 4.464613914489746




loss: 4.338709354400635




loss: 4.919558525085449




loss: 4.437676429748535




loss: 4.955129623413086




loss: 4.69556188583374




loss: 4.687591552734375




loss: 4.248404026031494




loss: 4.4623541831970215




loss: 4.561647891998291




loss: 4.809512615203857




loss: 4.620838642120361




loss: 4.446601390838623




loss: 4.485877513885498




loss: 4.470287322998047




loss: 5.2063422203063965




loss: 3.900991439819336




[1,   200] loss: 4.682
loss: 4.712407112121582




loss: 5.058010101318359




loss: 4.656675338745117




loss: 5.124155521392822




loss: 4.797517776489258




loss: 4.294869899749756




loss: 4.58936882019043




loss: 4.424036502838135




loss: 4.545863628387451




loss: 4.428104400634766




loss: 4.704966068267822




loss: 4.914971351623535




loss: 4.601301670074463




loss: 4.899731159210205




loss: 5.268355846405029




loss: 4.695630073547363




loss: 4.447272300720215




loss: 4.062030792236328




loss: 5.2621636390686035




loss: 4.619400978088379




loss: 4.516523838043213




loss: 4.708861827850342




loss: 4.852879524230957




loss: 4.6085638999938965




loss: 5.067668914794922




loss: 4.531530380249023




loss: 4.449597358703613




loss: 4.6188530921936035




loss: 4.522570610046387




loss: 4.535219192504883




loss: 4.6304240226745605




loss: 4.440567493438721




loss: 4.752870559692383




loss: 4.557904243469238




loss: 4.47506046295166




loss: 4.540004253387451




loss: 4.019121170043945




loss: 4.673568248748779




loss: 4.689591407775879




loss: 5.028331756591797




loss: 5.038360118865967




loss: 4.792250633239746




loss: 4.4516282081604




loss: 4.232194900512695




loss: 4.647489547729492




loss: 5.093954563140869




loss: 4.2804131507873535




loss: 4.489591121673584




loss: 5.044486045837402




loss: 4.272844314575195




loss: 4.67511510848999




loss: 4.741417407989502




loss: 4.354147911071777




loss: 4.621799468994141




loss: 4.525256633758545




loss: 4.781246185302734




loss: 4.4999494552612305




loss: 4.599035739898682




loss: 4.888777256011963




loss: 4.791810989379883




loss: 4.1468024253845215




loss: 4.292333602905273




loss: 4.179028511047363




loss: 4.815064907073975




loss: 4.605524063110352




loss: 4.883631706237793




loss: 4.165297508239746




loss: 4.815789222717285




loss: 4.903398513793945




loss: 4.603353023529053




loss: 4.7166619300842285




loss: 4.3577561378479




loss: 4.355078220367432




loss: 4.756087779998779




loss: 4.870640277862549




loss: 4.13724946975708




loss: 4.5989460945129395




loss: 4.478877544403076




loss: 4.7021331787109375




loss: 4.454265594482422




loss: 4.446233749389648




loss: 4.75246000289917




loss: 4.000689506530762




loss: 4.506263732910156




loss: 4.5240397453308105




loss: 4.576953887939453




loss: 4.3728508949279785




loss: 4.808928489685059




loss: 4.704259872436523




loss: 4.53375244140625




loss: 4.548285484313965




loss: 4.471199989318848




loss: 4.768104076385498




loss: 4.529867172241211




loss: 4.7271904945373535




loss: 4.476916790008545




loss: 4.120922088623047




loss: 4.658932209014893




loss: 4.848072052001953




loss: 4.86635160446167




[1,   300] loss: 4.612
loss: 4.851776123046875




loss: 4.829355239868164




loss: 4.390937328338623




loss: 4.55363130569458




loss: 4.706575393676758




loss: 4.24821662902832




loss: 4.665327072143555




loss: 4.55312442779541




loss: 4.060487747192383




loss: 4.1394944190979




loss: 4.34996223449707




loss: 4.246022701263428




loss: 4.8211350440979




Epoch 1 | Train Loss: 4.9875


  0%|          | 0/32 [00:00<?, ?it/s]
                                      

ValueError: Expected input batch_size (4160) to match target batch_size (4128).