In [4]:
%matplotlib inline
!pip install -U portalocker>=2.0.0
!pip install datasets torch transformers
!wget http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
!unzip -qq cornell_movie_dialogs_corpus.zip
!rm cornell_movie_dialogs_corpus.zip
!mkdir datasets
!mv cornell\ movie-dialogs\ corpus/movie_conversations.txt ./datasets
!mv cornell\ movie-dialogs\ corpus/movie_lines.txt ./datasets

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

In [5]:
import ast
import json
import math
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.utils.data import Dataset, DataLoader

In [6]:
# data processing
max_len = 25

# Characters stripped from every utterance. A raw string keeps the backslash as
# a literal character instead of the invalid escape sequence "\,".
_PUNCTUATION = r'''!()-[]{};:'"\,<>./?@#$%^&*_~'''
# Translation table that deletes every punctuation character (built once).
_PUNCT_TABLE = str.maketrans('', '', _PUNCTUATION)


def remove_punc(string):
    """Remove punctuation characters from `string` and lowercase the result.

    Whitespace is not in the punctuation set, so the result can still be
    split into words.

    Args:
        string: text to clean.

    Returns:
        The cleaned, lowercased text.
    """
    return string.translate(_PUNCT_TABLE).lower()

corpus_movie_conv = './datasets/movie_conversations.txt'
corpus_movie_lines = './datasets/movie_lines.txt'
field_sep = " +++$+++ "

with open(corpus_movie_conv, 'r', encoding='iso-8859-1') as conv_file:
    conv = conv_file.readlines()
with open(corpus_movie_lines, 'r', encoding='iso-8859-1') as lines_file:
    lines = lines_file.readlines()

# Map each line id (first field) to its raw utterance text (last field,
# still carrying the trailing newline; it is stripped below).
lines_dic = {}
for line in lines:
    objects = line.split(field_sep)
    lines_dic[objects[0]] = objects[-1]

# Generate question/answer pairs from consecutive lines of each conversation.
# Each pair is [question_tokens, answer_tokens], both truncated to max_len.
pairs = []
for con in conv:
    # The last field is a Python literal listing the conversation's line ids.
    # ast.literal_eval parses it without executing arbitrary code (eval() would
    # run whatever the downloaded file contains).
    ids = ast.literal_eval(con.split(field_sep)[-1])
    for first_id, second_id in zip(ids, ids[1:]):
        first = remove_punc(lines_dic[first_id].strip())
        second = remove_punc(lines_dic[second_id].strip())
        pairs.append([first.split()[:max_len], second.split()[:max_len]])


In [7]:
# Number of extracted pairs, then the question and reply of the first one.
print(len(pairs))
first_pair = pairs[0]
for side in (first_pair[0], first_pair[1]):
    print(side)

221616
['can', 'we', 'make', 'this', 'quick', 'roxanne', 'korrine', 'and', 'andrew', 'barrett', 'are', 'having', 'an', 'incredibly', 'horrendous', 'public', 'break', 'up', 'on', 'the', 'quad', 'again']
['well', 'i', 'thought', 'wed', 'start', 'with', 'pronunciation', 'if', 'thats', 'okay', 'with', 'you']


In [8]:
min_word_freq = 5

# Count every token on both sides of each (question, reply) pair.
word_freq = Counter()
for question, reply in pairs:
    word_freq.update(question)
    word_freq.update(reply)

# Keep words seen strictly MORE than `min_word_freq` times (note: `>`, not `>=`);
# rarer words get no entry in word_map.
words = [word for word, count in word_freq.items() if count > min_word_freq]

# Id 0 is reserved for <pad>; vocabulary words take 1..len(words), then the
# special tokens follow in the order <unk>, <start>, <end>.
word_map = {word: index for index, word in enumerate(words, start=1)}
for special in ('<unk>', '<start>', '<end>'):
    word_map[special] = len(word_map) + 1
word_map['<pad>'] = 0

print("Total words are: {}".format(len(word_map)))

def encode_question(words, word_map, seq_len=None):
    """Encode a tokenized question as a fixed-length list of vocabulary ids.

    Tokens missing from ``word_map`` become ``word_map['<unk>']``; the result
    is right-padded with ``word_map['<pad>']``.

    Args:
        words: list of token strings.
        word_map: dict mapping token -> id; must contain '<unk>' and '<pad>'.
        seq_len: output length; defaults to the notebook-level ``max_len``.
            Longer inputs are truncated so every encoding has exactly
            ``seq_len`` ids (previously the output grew past ``max_len``).

    Returns:
        list[int] of length ``seq_len``.
    """
    if seq_len is None:
        seq_len = max_len
    unk = word_map['<unk>']
    ids = [word_map.get(word, unk) for word in words[:seq_len]]
    return ids + [word_map['<pad>']] * (seq_len - len(ids))

def encode_reply(words, word_map, seq_len=None):
    """Encode a tokenized reply as <start> + ids + <end> + padding.

    Tokens missing from ``word_map`` become ``word_map['<unk>']``.

    Args:
        words: list of token strings.
        word_map: dict mapping token -> id; must contain '<unk>', '<start>',
            '<end>' and '<pad>'.
        seq_len: maximum number of word ids; defaults to the notebook-level
            ``max_len``. Longer inputs are truncated so every encoding has
            exactly ``seq_len + 2`` ids.

    Returns:
        list[int] of length ``seq_len + 2``.
    """
    if seq_len is None:
        seq_len = max_len
    unk = word_map['<unk>']
    body = [word_map.get(word, unk) for word in words[:seq_len]]
    return ([word_map['<start>']] + body + [word_map['<end>']]
            + [word_map['<pad>']] * (seq_len - len(body)))
# Encode every (question, reply) pair into id lists.
pairs_encoded = [
    [encode_question(question, word_map), encode_reply(reply, word_map)]
    for question, reply in pairs
]

print(pairs_encoded[0][0])
print(pairs_encoded[0][1])

Total words are: 18243
[1, 2, 3, 4, 5, 18240, 18240, 6, 7, 8, 9, 10, 11, 12, 18240, 13, 14, 15, 16, 17, 18240, 18, 0, 0, 0]
[18241, 19, 20, 21, 22, 23, 24, 18240, 25, 26, 27, 24, 28, 18242, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


In [13]:
def subsequent_mask(size, device=None):
  """Causal (look-ahead) mask for a sequence of length `size`.

  Args:
    size: sequence length.
    device: optional torch device for the mask (None -> torch's default device).

  Returns:
    Bool tensor of shape (1, size, size); entry [0, i, j] is True iff j <= i,
    i.e. position i may attend to position j.
  """
  # Build the lower triangle directly as bool instead of casting a float
  # upper triangle to torch.uint16 (a limited-support dtype that older
  # PyTorch releases do not have) and comparing with 0.
  return torch.tril(torch.ones((1, size, size), dtype=torch.bool, device=device))

def create_mask(src, tgt, pad):
  """Build the source padding mask and the target padding+causal mask.

  Args:
    src: (batch, src_len) token ids.
    tgt: (batch, tgt_len) token ids.
    pad: id used for padding.

  Returns:
    src_mask: bool (batch, 1, src_len); True at non-pad source positions.
    tgt_mask: bool (batch, tgt_len, tgt_len); True where target position i
      may attend to j, i.e. j is not padding and j <= i.
  """
  tgt_len = tgt.size(1)
  src_mask = (src != pad).unsqueeze(-2)
  # subsequent_mask is built on the default device; move it next to tgt so
  # the `&` below also works for CUDA batches.
  causal = subsequent_mask(tgt_len).to(tgt.device)
  tgt_mask = (tgt != pad).unsqueeze(-2) & causal
  return src_mask, tgt_mask


# Toy batch: two source rows (the second padded) and two target rows.
pad = 0
src = torch.tensor([[1, 2, 3], [4, 5, 0]], dtype=torch.long)
tgt = torch.tensor([[1, 2, 3, 4], [1, 2, 0, 0]], dtype=torch.long)
src_mask, tgt_mask = create_mask(src, tgt, pad)
print(src_mask, tgt_mask, sep="\n")

tensor([[[ True,  True,  True]],

        [[ True,  True, False]]])
tensor([[[ True, False, False, False],
         [ True,  True, False, False],
         [ True,  True,  True, False],
         [ True,  True,  True,  True]],

        [[ True, False, False, False],
         [ True,  True, False, False],
         [ True,  True, False, False],
         [ True,  True, False, False]]])


In [None]:
class Embeddings(nn.Module):
  """Token-embedding lookup scaled by sqrt(d_model) (Vaswani et al., 2017, Sec. 3.4).

  Args:
    d_model: embedding width.
    vocab_size: number of distinct token ids (rows of the lookup table).
  """

  def __init__(self, d_model, vocab_size):
    super().__init__()
    self.lut = nn.Embedding(vocab_size, d_model)
    self.d_model = d_model

  def forward(self, x):
    """Map token ids `x` (LongTensor, any shape) to a tensor of shape x.shape + (d_model,)."""
    return self.lut(x) * math.sqrt(self.d_model)


class PositionalEncoding(nn.Module):
  """Add fixed sinusoidal position encodings to a batch of embeddings, then dropout.

  PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
  PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))
  (Vaswani et al., 2017, Sec. 3.5).

  Args:
    d_model: embedding width (odd widths are supported).
    dropout: dropout probability applied after adding the encodings.
    max_len: longest sequence length covered by the precomputed table.
  """

  def __init__(self, d_model, dropout, max_len=5000):
    super().__init__()
    self.dropout = nn.Dropout(p=dropout)
    position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
    )  # (ceil(d_model / 2),)
    angles = position * div_term  # (max_len, ceil(d_model / 2))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(angles)
    # For odd d_model there is one fewer odd column than even column.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
    # Buffer rather than Parameter: follows .to(device) and state_dict but is not trained.
    self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

  def forward(self, x):
    """x: (batch, seq_len, d_model) with seq_len <= max_len; returns the same shape."""
    return self.dropout(x + self.pe[:, : x.size(1)])

In [None]:
import copy

def clones(module, N):
  """Return an nn.ModuleList holding N deep copies of `module`.

  Each copy is produced by copy.deepcopy, so no parameter tensors are shared
  between copies; all start from `module`'s current values.
  """
  copies = []
  for _ in range(N):
    copies.append(copy.deepcopy(module))
  return nn.ModuleList(copies)

def attention(query, key, value, mask=None, dropout=None):
  """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

  Args:
    query: (..., q_len, d_k) tensor.
    key: (..., k_len, d_k) tensor.
    value: (..., k_len, d_v) tensor.
    mask: optional tensor broadcastable to (..., q_len, k_len); positions where
      mask == 0 (False) are excluded from the softmax.
    dropout: optional callable (e.g. nn.Dropout) applied to the attention weights.

  Returns:
    (output, weights): output is (..., q_len, d_v); weights is (..., q_len, k_len).
  """
  d_k = query.size(-1)
  scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
  if mask is not None:
    # The dtype's most negative finite value (rather than a fixed -1e9) also
    # fits in float16/bfloat16, where -1e9 overflows.
    scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
  weights = F.softmax(scores, dim=-1)
  if dropout is not None:
    weights = dropout(weights)
  return torch.matmul(weights, value), weights


class MultiHeadedAttention(nn.Module):
  """Multi-head attention: h scaled dot-product attention heads run in parallel.

  Args:
    h: number of heads; must divide d_model.
    d_model: model width.
    dropout: dropout probability applied to the attention weights.

  Raises:
    ValueError: if d_model is not divisible by h.
  """

  def __init__(self, h, d_model, dropout=0.1):
    super().__init__()
    if d_model % h != 0:
      raise ValueError(f"d_model ({d_model}) must be divisible by h ({h})")
    self.d_k = d_model // h
    self.h = h
    # Projections for query, key, value, and the final output (in that order).
    self.linears = clones(nn.Linear(d_model, d_model), 4)
    self.attn = None  # attention weights from the most recent forward pass
    self.dropout = nn.Dropout(p=dropout)

  def forward(self, query, key, value, mask=None):
    """query: (batch, q_len, d_model); key/value: (batch, k_len, d_model).

    mask, if given, is assumed broadcastable to (batch, q_len, k_len), e.g.
    (batch, 1, k_len) or (batch, q_len, k_len). Returns (batch, q_len, d_model).
    """
    if mask is not None:
      mask = mask.unsqueeze(1)  # add a head axis: same mask for every head
    batch_size = query.size(0)
    # Project, then split d_model into h heads -> (batch, h, seq_len, d_k).
    query, key, value = [
        proj(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        for proj, x in zip(self.linears, (query, key, value))
    ]
    x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)
    # Merge the heads back -> (batch, q_len, d_model).
    x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
    return self.linears[-1](x)

In [None]:
class FeedForward(nn.Module):
  """Position-wise feed-forward network.

  Computes Linear(d_ff -> d_model)(dropout(relu(Linear(d_model -> d_ff)(x)))),
  applied independently at every position.

  Args:
    d_model: input/output width.
    d_ff: hidden width.
    dropout: dropout probability applied after the ReLU.
  """

  def __init__(self, d_model, d_ff, dropout):
    super().__init__()
    self.w_1 = nn.Linear(d_model, d_ff)
    self.w_2 = nn.Linear(d_ff, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    """x: (..., d_model) -> (..., d_model)."""
    return self.w_2(self.dropout(F.relu(self.w_1(x))))


class LayerNorm(nn.Module):
  """Layer normalization over the last dimension with a learnable gain and bias.

  Computes a_2 * (x - mean) / sqrt(var + eps) + b_2 using the biased variance,
  matching torch.nn.LayerNorm(features, eps=eps). The biased variance keeps a
  single-feature input finite (an unbiased std would be NaN there).

  Args:
    features: size of the last dimension.
    eps: added to the variance for numerical stability.
  """

  def __init__(self, features, eps=1e-6):
    super().__init__()
    self.a_2 = nn.Parameter(torch.ones(features))   # gain
    self.b_2 = nn.Parameter(torch.zeros(features))  # bias
    self.eps = eps

  def forward(self, x):
    """x: (..., features) -> same shape."""
    mean = x.mean(dim=-1, keepdim=True)
    centered = x - mean
    var = (centered * centered).mean(dim=-1, keepdim=True)
    return self.a_2 * centered / torch.sqrt(var + self.eps) + self.b_2

class SublayerConnection(nn.Module):
  """Pre-norm residual wrapper: returns x + dropout(sublayer(norm(x))).

  Args:
    siez: feature size (model width) normalized by the internal LayerNorm.
      The misspelled name is kept so existing keyword callers keep working.
    dropout: dropout probability applied to the sublayer output.
  """

  def __init__(self, siez, dropout):
    super().__init__()
    self.norm = LayerNorm(siez)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, sublayer):
    """Apply `sublayer` (callable: tensor -> same-shaped tensor) with a residual connection."""
    return x + self.dropout(sublayer(self.norm(x)))

In [None]:
class EncoderLayer(nn.Module):
  """One encoder layer: self-attention then feed-forward, each in a SublayerConnection.

  Args:
    size: model width.
    self_attn: module called as self_attn(query, key, value, mask).
    ffn: module mapping (..., size) -> (..., size).
    dropout: dropout probability for both residual sublayers.
  """

  def __init__(self, size, self_attn, ffn, dropout):
    super().__init__()
    self.size = size
    self.self_attn = self_attn
    self.ffn = ffn
    self.sublayer = clones(SublayerConnection(size, dropout), 2)

  def forward(self, x, mask):
    """x: (batch, src_len, size); mask: source attention mask."""
    x = self.sublayer[0](x, lambda t: self.self_attn(t, t, t, mask))
    return self.sublayer[1](x, self.ffn)


class Encoder(nn.Module):
  """Stack of N deep copies of `layer` followed by a final LayerNorm.

  Args:
    layer: encoder layer called as layer(x, mask); must expose `layer.size`
      (the model width), which sizes the final LayerNorm.
    N: number of stacked copies.
  """

  def __init__(self, layer, N):
    super().__init__()
    self.layers = clones(layer, N)
    self.norm = LayerNorm(layer.size)

  def forward(self, x, mask):
    """Run x through every layer in turn, then normalize."""
    for layer in self.layers:
      x = layer(x, mask)
    return self.norm(x)

class DecoderLayer(nn.Module):
  """One decoder layer: masked self-attention, cross-attention over the encoder
  memory, then feed-forward; each wrapped in a SublayerConnection.

  Args:
    size: model width.
    self_attn: module called as self_attn(query, key, value, mask).
    cross_attn: module called as cross_attn(query, key, value, mask).
    fnn: feed-forward module mapping (..., size) -> (..., size); the parameter
      name is kept as-is for caller compatibility.
    dropout: dropout probability for the three residual sublayers.
  """

  def __init__(self, size, self_attn, cross_attn, fnn, dropout):
    super().__init__()
    self.size = size
    self.self_attn = self_attn
    self.cross_attn = cross_attn
    self.ffn = fnn
    self.sublayer = clones(SublayerConnection(size, dropout), 3)

  def forward(self, x, memory, src_mask, tgt_mask):
    """x: (batch, tgt_len, size); memory: encoder output (batch, src_len, size)."""
    x = self.sublayer[0](x, lambda t: self.self_attn(t, t, t, tgt_mask))
    x = self.sublayer[1](x, lambda t: self.cross_attn(t, memory, memory, src_mask))
    return self.sublayer[2](x, self.ffn)

class Decoder(nn.Module):
  """Stack of N deep copies of `layer` followed by a final LayerNorm.

  Args:
    layer: decoder layer called as layer(x, memory, src_mask, tgt_mask); must
      expose `layer.size` (the model width), which sizes the final LayerNorm.
    N: number of stacked copies.
  """

  def __init__(self, layer, N):
    super().__init__()
    self.layers = clones(layer, N)
    self.norm = LayerNorm(layer.size)

  def forward(self, x, memory, src_mask, tgt_mask):
    """Run the target representation x through every layer, then normalize."""
    for layer in self.layers:
      x = layer(x, memory, src_mask, tgt_mask)
    return self.norm(x)

  # This method was originally misspelled `forwward`; nn.Module.__call__ only
  # dispatches to `forward`. The old name stays as an alias for direct callers.
  forwward = forward


class Generator(nn.Module):
  """Final linear projection to vocabulary log-probabilities.

  Args:
    d_model: input width.
    vocab_size: number of output classes.
  """

  def __init__(self, d_model, vocab_size):
    super().__init__()
    self.proj = nn.Linear(d_model, vocab_size)

  def forward(self, x):
    """x: (..., d_model) -> log-probabilities (..., vocab_size) over the last axis."""
    return F.log_softmax(self.proj(x), dim=-1)


class Transformer(nn.Module):
  """Encoder-decoder model wiring embeddings, encoder, decoder and generator.

  `forward` returns the decoder output; `generator` is stored but not applied
  by `forward`, so callers apply `self.generator` separately.
  """

  def __init__(self, src_embed, tgt_embed, encoder, decoder, generator):
    super().__init__()
    self.src_embed = src_embed
    self.tgt_embed = tgt_embed
    self.encoder = encoder
    self.decoder = decoder
    self.generator = generator

  def forward(self, src, tgt, src_mask, tgt_mask):
    """Encode `src`, then decode `tgt` against the resulting memory."""
    memory = self.encode(src, src_mask)
    return self.decode(memory, src_mask, tgt, tgt_mask)

  def encode(self, src, src_mask):
    """Embed the source and run it through the encoder."""
    return self.encoder(self.src_embed(src), src_mask)

  def decode(self, memory, src_mask, tgt, tgt_mask):
    """Embed the target and run it through the decoder with the encoder memory."""
    return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)

def create_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
  """Build a Transformer from hyper-parameters.

  The defaults are the "base" configuration of Vaswani et al. (2017).

  Args:
    src_vocab: source vocabulary size.
    tgt_vocab: target vocabulary size.
    N: number of encoder layers and of decoder layers.
    d_model: model width.
    d_ff: hidden width of the feed-forward sublayers.
    h: number of attention heads.
    dropout: dropout probability used throughout.

  Returns:
    A Transformer whose weight matrices are Xavier-uniform initialised.
  """
  c = copy.deepcopy
  attn = MultiHeadedAttention(h, d_model, dropout)
  ff = FeedForward(d_model, d_ff, dropout)
  position = PositionalEncoding(d_model, dropout)
  model = Transformer(
      nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
      nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
      Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
      Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
      Generator(d_model, tgt_vocab),
  )
  # Xavier-uniform for every parameter with more than one dimension;
  # 1-D parameters (biases, norm gains) keep their constructor defaults.
  for p in model.parameters():
    if p.dim() > 1:
      nn.init.xavier_uniform_(p)
  return model