## Get dataset from Kaggle

In [None]:
# SECURITY: never hard-code API credentials in a notebook -- they end up in
# version control and in every shared copy. The credentials are read from the
# KAGGLE_USERNAME / KAGGLE_KEY environment variables instead. Any key that was
# previously committed in this cell must be treated as leaked and revoked.
import json
import os
from pathlib import Path

kaggle_dir = Path.home() / ".kaggle"
kaggle_dir.mkdir(parents=True, exist_ok=True)

kaggle_credentials = kaggle_dir / "kaggle.json"
# Make the file owner-only *before* the secret is written into it.
kaggle_credentials.touch(mode=0o600, exist_ok=True)
kaggle_credentials.chmod(0o600)
kaggle_credentials.write_text(json.dumps({
    "username": os.environ["KAGGLE_USERNAME"],  # KeyError here means the variable is not set
    "key": os.environ["KAGGLE_KEY"],
}))

In [None]:
# Download the transcript dataset archive into the current working directory
# using the Kaggle CLI (it authenticates with ~/.kaggle/kaggle.json).
!kaggle datasets download -d nasirkhalid24/the-office-us-complete-dialoguetranscript

Downloading the-office-us-complete-dialoguetranscript.zip to /content
  0% 0.00/1.37M [00:00<?, ?B/s]
100% 1.37M/1.37M [00:00<00:00, 32.6MB/s]


In [None]:
!ls

sample_data  the-office-us-complete-dialoguetranscript.zip


In [None]:
# Extract the downloaded archive, then list the directory to confirm which
# files were produced.
!unzip the-office-us-complete-dialoguetranscript.zip
!ls

Archive:  the-office-us-complete-dialoguetranscript.zip
  inflating: The-Office-Lines-V4.csv  
sample_data		 the-office-us-complete-dialoguetranscript.zip
The-Office-Lines-V4.csv


In [None]:
# Give the extracted CSV a short, stable name; later cells read "data.csv".
!mv The-Office-Lines-V4.csv data.csv

# EDA and Pre-processing

In [None]:
import pandas as pd
# Load the raw transcript; each row is one spoken line.
df = pd.read_csv("data.csv")
# Preview the first rows to see which columns are available.
df.head()

Unnamed: 0,season,episode,title,scene,speaker,line,Unnamed: 6
0,1,1,Pilot,1,Michael,All right Jim. Your quarterlies look very good...,
1,1,1,Pilot,1,Jim,"Oh, I told you. I couldn't close it. So...",
2,1,1,Pilot,1,Michael,So you've come to the master for guidance? Is ...,
3,1,1,Pilot,1,Jim,"Actually, you called me in here, but yeah.",
4,1,1,Pilot,1,Michael,"All right. Well, let me show you how it's done.",


### Drop unnecessary columns

In [None]:
# Columns that are not used by any later step of the notebook.
unused_columns = ["season", "episode", "Unnamed: 6"]
df = df.drop(columns=unused_columns)
df.head()

Unnamed: 0,title,scene,speaker,line
0,Pilot,1,Michael,All right Jim. Your quarterlies look very good...
1,Pilot,1,Jim,"Oh, I told you. I couldn't close it. So..."
2,Pilot,1,Michael,So you've come to the master for guidance? Is ...
3,Pilot,1,Jim,"Actually, you called me in here, but yeah."
4,Pilot,1,Michael,"All right. Well, let me show you how it's done."


## Drop low-quality data

The dataset is likely compiled from various sources, and some rows have the
speaker name written as "Michael: ", i.e. with a trailing colon and space. As
can be seen below, the dialogue text in these rows is truncated at the start,
so they are low-quality samples. We drop them.

In [None]:
df[df["speaker"] == "Michael: "]

Unnamed: 0,title,scene,speaker,line
31793,Happy Hour,4846,Michael:,w many is that?
31795,Happy Hour,4846,Michael:,unt the last one.
31797,Happy Hour,4846,Michael:,", new record!"
31799,Happy Hour,4846,Michael:,", what did you do today?"
31801,Happy Hour,4846,Michael:,", yeah, sitting on your big fat butt. Alright,..."
...,...,...,...,...
32088,Happy Hour,4888,Michael:,is I.
32090,Happy Hour,4888,Michael:,", hey guys."
32102,Happy Hour,4890,Michael:,"y, Julie! You having fun?"
32145,Happy Hour,4896,Michael:,"lperts, wait up. Oh, what a great night. Got t..."


In [None]:
# Keep only the rows whose speaker name does NOT end with ": ".
df = df[~df['speaker'].str.endswith(': ')]
# Sanity check: the same query as before should now return no rows.
df[df["speaker"] == "Michael: "]

Unnamed: 0,title,scene,speaker,line


### Correct typos: Deangelo > DeAngelo

In [None]:
# Mapping of misspelled speaker name -> canonical spelling.
typos = {
    "Deangelo": "DeAngelo"
}

# NOTE(review): with regex=True the keys are treated as patterns, so this
# presumably also rewrites "Deangelo" when it appears inside a longer speaker
# string -- confirm that is intended rather than exact-match replacement.
df["speaker"] = df["speaker"].replace(typos, regex=True)

### Drop lines from characters that don't fall in the top 15 in terms of number of lines

In [None]:
# Number of most frequent speakers to keep.
TOP_COUNT = 15

# Count rows per speaker and keep the names of the TOP_COUNT largest groups.
top_speakers = df.value_counts("speaker").head(TOP_COUNT).keys()
print(top_speakers)

Index(['Michael', 'Dwight', 'Jim', 'Pam', 'Andy', 'Kevin', 'Angela', 'Erin',
       'Oscar', 'Ryan', 'Darryl', 'Phyllis', 'Kelly', 'Toby', 'Jan'],
      dtype='object', name='speaker')


In [None]:
# Drop every row spoken by someone outside `top_speakers`.
df = df[df["speaker"].isin(top_speakers)]
# Rebuild a contiguous 0..N-1 index and discard the old one.
df = df.reset_index(drop=True)
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 43981 entries, 0 to 43980
Data columns (total 4 columns):
 #   Column   Non-Null Count  Dtype 
---  ------   --------------  ----- 
 0   title    43981 non-null  object
 1   scene    43981 non-null  int64 
 2   speaker  43981 non-null  object
 3   line     43981 non-null  object
dtypes: int64(1), object(3)
memory usage: 1.3+ MB


In [None]:
df["speaker"].value_counts()

Michael    10773
Dwight      6752
Jim         6222
Pam         4973
Andy        3698
Kevin       1535
Angela      1534
Erin        1413
Oscar       1336
Ryan        1182
Darryl      1160
Phyllis      962
Kelly        822
Toby         814
Jan          805
Name: speaker, dtype: int64

### Save new CSV to disk

In [None]:
# A file handle given to `DataFrame.to_csv` should be opened with newline=""
# so that no extra newline translation is applied on top of the CSV writer's
# own line terminators (otherwise Windows output gets doubled "\r").
with open("processed_data.csv", "w", newline="") as out_fd:
    df.to_csv(out_fd, index=False)

In [None]:
!ls -lh

total 9.5M
-rw-r--r-- 1 root root 4.6M Jan 18  2021 data.csv
-rw-r--r-- 1 root root 3.5M Apr 18 23:54 processed_data.csv
drwxr-xr-x 1 root root 4.0K Apr 14 13:35 sample_data
-rw-r--r-- 1 root root 1.4M Apr 18 23:54 the-office-us-complete-dialoguetranscript.zip


In [None]:
!head -10 processed_data.csv

title,scene,speaker,line
Pilot,1,Michael,All right Jim. Your quarterlies look very good. How are things at the library?
Pilot,1,Jim,"Oh, I told you. I couldn't close it. So..."
Pilot,1,Michael,"So you've come to the master for guidance? Is this what you're saying, grasshopper?"
Pilot,1,Jim,"Actually, you called me in here, but yeah."
Pilot,1,Michael,"All right. Well, let me show you how it's done."
Pilot,2,Michael," Yes, I'd like to speak to your office manager, please. Yes, hello. This is Michael Scott. I am the Regional Manager of Dunder Mifflin Paper Products. Just wanted to talk to you manager-a-manger.  All right. Done deal. Thank you very much, sir. You're a gentleman and a scholar. Oh, I'm sorry. OK. I'm sorry. My mistake.  That was a woman I was talking to, so... She had a very low voice. Probably a smoker, so...  So that's the way it's done."
Pilot,3,Michael,"I've, uh, I've been at Dunder Mifflin for 12 years, the last four as Regional Manager. If you want to come through here

# Convert .csv to a textual script for tokenization

### Meta tokens for the script text

In [None]:
# Marker tokens written into (and later parsed back out of) the script text.

# Delimit one scene.
SCENE_START = "<scene_start>"
SCENE_END = "<scene_end>"

# Wrap the name of the character who speaks the following line.
SPEAKER_START = "<speaker_start>"
SPEAKER_END = "<speaker_end>"

# Delimit everything one character says in a single CSV row.
LINE_START = "<line_start>"
LINE_END = "<line_end>"

# Delimit one sentence inside a line.
SENT_START = "<sent_start>"
SENT_END = "<sent_end>"

# Upper-case, unlike the other markers; the post-processing code lower-cases it
# and replaces it with "\n".
NEWLINE = "<NEWLINE>"

In [None]:
import csv
import nltk
import string
from nltk.tokenize import sent_tokenize

nltk.download('punkt')

# Used to remove punctuation from strings
translator = str.maketrans('', '', string.punctuation)

# Convert the processed CSV into one long, space-separated token stream:
#   <scene_start> <speaker_start> NAME <speaker_end> <line_start>
#   <sent_start> words ... <sent_end> ... <line_end> ... <scene_end>
# The csv module requires its input file to be opened with newline="" so that
# quoted fields containing line breaks are parsed correctly.
with open("script.txt", "w") as out_fd, open("processed_data.csv", newline="") as in_fd:
    out_fd.write(SCENE_START + " ")

    # None until the first row is seen, so the first scene never produces a
    # spurious empty scene regardless of which scene number the data starts at.
    current_scene = None
    for row in csv.DictReader(in_fd):
        row_scene = int(row["scene"])
        if current_scene is not None and row_scene != current_scene:
            out_fd.write(SCENE_END + " " + SCENE_START + " ")
        current_scene = row_scene

        out_fd.write(f"{SPEAKER_START} {row['speaker']} {SPEAKER_END} {LINE_START} ")

        # A line may have multiple sentences
        for sentence in sent_tokenize(row['line']):
            sentence = sentence.translate(translator).strip()
            # A "sentence" made only of punctuation (e.g. "...") is empty after
            # stripping; skip it instead of emitting an empty sentence pair.
            if not sentence:
                continue
            out_fd.write(f"{SENT_START} {sentence} {SENT_END} ")

        out_fd.write(LINE_END + " ")

    out_fd.write(SCENE_END + " ")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [None]:
# Checking the first 500 characters of the file
!head -c500 script.txt

<scene_start> <speaker_start> Michael <speaker_end> <line_start> <sent_start> All right Jim <sent_end> <sent_start> Your quarterlies look very good <sent_end> <sent_start> How are things at the library <sent_end> <line_end> <speaker_start> Jim <speaker_end> <line_start> <sent_start> Oh I told you <sent_end> <sent_start> I couldnt close it <sent_end> <sent_start> So <sent_end> <line_end> <speaker_start> Michael <speaker_end> <line_start> <sent_start> So youve come to the master for guidance <sent

In [None]:
import csv

# Write a plain "Speaker: line" transcript, one CSV row per output line.
# No sentence splitting or punctuation stripping is done here, so this cell
# needs neither nltk nor the punctuation translator, and scene numbers are
# not used. The csv module requires newline="" on its input file.
with open("script_simple.txt", "w") as out_fd, open("processed_data.csv", newline="") as in_fd:
    for row in csv.DictReader(in_fd):
        out_fd.write(f"{row['speaker']}: {row['line']}\n")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [None]:
# Checking the first 500 characters of the file
!head -c500 script_simple.txt

Michael: All right Jim. Your quarterlies look very good. How are things at the library?
Jim: Oh, I told you. I couldn't close it. So...
Michael: So you've come to the master for guidance? Is this what you're saying, grasshopper?
Jim: Actually, you called me in here, but yeah.
Michael: All right. Well, let me show you how it's done.
Michael:  Yes, I'd like to speak to your office manager, please. Yes, hello. This is Michael Scott. I am the Regional Manager of Dunder Mifflin Paper Products. Just w

# Transformer model

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class PositionalEncoding(nn.Module):
    r"""Add fixed sinusoidal position information to a sequence of embeddings.

    The positional encodings have the same dimension as the embeddings, so that
    the two can be summed. Sine and cosine functions of different frequencies
    are used:

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / d_{model}})

        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / d_{model}})

    where ``pos`` is the token position and ``i`` indexes the embedding dimension.

    Args:
        d_model: the embed dim (required). Both even and odd values are supported.
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one cosine column fewer than sine columns,
        # so trim the frequencies; for an even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Shape (max_len, 1, d_model): broadcasts over the batch dimension.
        pe = pe.unsqueeze(1)
        # A buffer moves with the module across devices but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Add the positional encoding to ``x`` and apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        Examples:
            >>> output = pos_encoder(x)
        """
        if x.size(0) > self.pe.size(0):
            raise ValueError(
                f"sequence length {x.size(0)} exceeds max_len {self.pe.size(0)}"
            )

        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):
    """Token embedding + positional encoding + transformer encoder stack + linear decoder.

    Args:
        ntoken: vocabulary size (number of embeddings and of output classes).
        ninp: embedding dimension, also the transformer model dimension.
        nhead: number of attention heads.
        nhid: feed-forward dimension inside each encoder layer.
        nlayers: number of stacked encoder layers.
        dropout: dropout used by the positional encoder and the encoder layers.
    """

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        try:
            from torch.nn import TransformerEncoder, TransformerEncoderLayer
        except ImportError as e:
            # Only a failed import is translated; KeyboardInterrupt, SystemExit
            # and other unrelated errors are no longer swallowed.
            raise ImportError('TransformerEncoder module does not exist in PyTorch 1.1 or '
                              'lower.') from e
        self.model_type = 'Transformer'
        # Cached attention mask; rebuilt in forward() when the sequence length changes.
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        """Return an (sz, sz) float mask: 0.0 on/below the diagonal, -inf above it."""
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def init_weights(self):
        """Initialise embedding and decoder weights uniformly in [-0.1, 0.1]; zero the decoder bias."""
        initrange = 0.1
        nn.init.uniform_(self.encoder.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def forward(self, src, has_mask=True):
        """Compute per-position log-probabilities over the vocabulary.

        Args:
            src: tensor of token ids whose FIRST dimension is the sequence
                position (the mask is sized from ``len(src)`` and the positional
                encoder indexes positions along dim 0).
            has_mask: when True, apply the square subsequent mask built by
                ``_generate_square_subsequent_mask``; when False, no mask is used.
        Returns:
            ``log_softmax`` of the decoder output over the last dimension.
        """
        if has_mask:
            device = src.device
            if self.src_mask is None or self.src_mask.size(0) != len(src):
                mask = self._generate_square_subsequent_mask(len(src)).to(device)
                self.src_mask = mask
        else:
            self.src_mask = None

        # Scale embeddings by sqrt(ninp) before adding the positional encoding.
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return F.log_softmax(output, dim=-1)

In [None]:
def train(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Each batch is unpacked as ``(src, tgt)``, moved to ``device``, and used for
    a single forward/backward/step cycle with the model in training mode.
    The result is the sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.train()
    batch_losses = []
    for src, tgt in dataloader:
        src = src.to(device)
        tgt = tgt.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(src), tgt)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(dataloader)

In [None]:
from torchtext.data.utils import get_tokenizer

# torchtext's 'basic_english' tokenizer; the sample output shows lower-cased
# tokens with the <...> marker tokens kept intact.
tokenizer = get_tokenizer('basic_english')
# Read the whole generated script into memory as one string.
with open("script.txt", 'r') as fd:
    data = fd.read()

# Flat list of string tokens for the entire corpus; used by the vocabulary and
# dataset cells below.
tokens = tokenizer(data)

In [None]:
tokens[:10]

['<scene_start>',
 '<speaker_start>',
 'michael',
 '<speaker_end>',
 '<line_start>',
 '<sent_start>',
 'all',
 'right',
 'jim',
 '<sent_end>']

In [None]:
from torchtext.vocab import build_vocab_from_iterator

# Build the token -> index vocabulary from the single corpus token list.
# NOTE(review): no special/unknown token or default index is configured here,
# so looking up a token that never occurs in the corpus presumably raises --
# confirm before encoding arbitrary user-supplied seed text.
vocab = build_vocab_from_iterator([tokens])

In [None]:
from torch.utils.data import Dataset, DataLoader, Subset
import torch.nn.functional as F

# Define the dataset
class ScriptDataset(Dataset):
    """Sliding-window next-token dataset over a flat list of string tokens.

    Item ``idx`` is the pair ``(x, y)`` where ``x`` holds the vocabulary ids of
    tokens ``idx .. idx + seq_len - 1`` (int64, shape ``(seq_len,)``) and ``y``
    is the float one-hot encoding (shape ``(len(vocabulary),)``) of the token
    that immediately follows the window.

    Args:
        tokens: list of string tokens for the whole corpus.
        seq_len: number of context tokens per sample (default=50).
        vocabulary: mapping supporting ``vocabulary[token] -> int`` and
            ``len(vocabulary)``. Defaults to the notebook-level ``vocab``.
    """

    def __init__(self, tokens, seq_len=50, vocabulary=None):
        self.tokens = tokens
        self.seq_len = seq_len
        self.vocabulary = vocab if vocabulary is None else vocabulary
        self.num_classes = len(self.vocabulary)
        # Encode the corpus once up front instead of doing seq_len + 1
        # vocabulary lookups on every __getitem__ call.
        self.token_ids = torch.tensor(
            [self.vocabulary[token] for token in tokens], dtype=torch.long
        )

    def __len__(self):
        # Every window needs one following token as its target; never negative.
        return max(0, len(self.tokens) - self.seq_len)

    def __getitem__(self, idx):
        target_pos = idx + self.seq_len
        # Index the target first so an out-of-range idx raises IndexError
        # rather than silently yielding a short window.
        target_id = self.token_ids[target_pos]

        x = self.token_ids[idx:target_pos].clone()
        y = F.one_hot(target_id, num_classes=self.num_classes).float()

        return x, y

In [None]:
# Number of context tokens in each training window.
INPUT_SEQ_LEN = 50
# Embedding size; passed to the model as `ninp`.
POS_EMBED_SIZE = 256
# Number of windows per DataLoader batch.
BATCH_SIZE = 128
# Learning rate given to the Adam optimizer.
LEARNING_RATE = 0.01
# Number of full passes over the dataloader in the training loop.
EPOCHS = 1

In [None]:
dataset = ScriptDataset(tokens, seq_len=INPUT_SEQ_LEN)
# For quick experiments, train on a prefix of the data instead:
# SUBSET_SIZE = 10_000
# dataset = Subset(dataset, list(range(SUBSET_SIZE)))

# Consecutive items are windows shifted by a single token, so an unshuffled
# batch would consist of BATCH_SIZE almost identical samples. Shuffle so each
# batch draws windows from across the whole script.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
import torch.optim as optim

# Initialize the model with desired parameters
ntoken = len(vocab)        # vocabulary size = number of output classes
ninp = POS_EMBED_SIZE      # embedding / model dimension
nhead = 8                  # attention heads per encoder layer
nhid = 128                 # feed-forward width inside each encoder layer
nlayers = 3                # number of stacked encoder layers
dropout = 0.5
model = TransformerModel(ntoken, ninp, nhead, nhid, nlayers, dropout)

# Define loss function and optimizer
# NOTE(review): TransformerModel.forward already returns log_softmax output and
# CrossEntropyLoss applies log-softmax itself; this is presumably harmless
# (log-softmax of log-probabilities leaves them unchanged) -- confirm.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [None]:
model = model.to(device)

In [None]:
print(model)

TransformerModel(
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=128, bias=True)
        (dropout): Dropout(p=0.5, inplace=False)
        (linear2): Linear(in_features=128, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.5, inplace=False)
        (dropout2): Dropout(p=0.5, inplace=False)
      )
    )
  )
  (encoder): Embedding(18862, 256)
  (decoder): Linear(in_features=256, out_features=18862, bias=True)
)


In [None]:
# Train the model
# Make sure dropout is active even if an earlier cell switched to eval mode.
model.train()

for epoch in range(EPOCHS):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(dataloader):
        # The DataLoader yields inputs as (batch, seq_len), but the model works
        # on (seq_len, batch): its attention mask and positional encoding are
        # indexed along dim 0. Without this transpose, attention would mix
        # different samples of the batch instead of the tokens of one sample.
        inputs = inputs.t().contiguous().to(device)
        labels = labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass: (seq_len, batch, ntoken)
        outputs = model(inputs)
        # Only the prediction made at the last position is trained on,
        # giving (batch, ntoken) to match the one-hot labels.
        outputs = outputs[-1]

        # Compute loss
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Print statistics
        running_loss += loss.item()
        if i % 1000 == 999:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 1000))
            running_loss = 0.0

print('Finished training')

[1,  1000] loss: 5.354
[1,  2000] loss: 5.285
[1,  3000] loss: 5.327
[1,  4000] loss: 5.276
[1,  5000] loss: 5.311
[1,  6000] loss: 5.374
Finished training


In [None]:
def encode_text(text):
    """Tokenize ``text`` with the notebook tokenizer and map each token to its vocabulary index."""
    indexes = []
    for token in tokenizer(text):
        indexes.append(vocab[token])
    return indexes

In [None]:
def decode_text(indexes):
    """Map vocabulary indexes back to tokens and join them with single spaces."""
    # Fetch the index -> token table once; calling get_itos() per index would
    # rebuild the whole table for every generated token.
    itos = vocab.get_itos()
    return " ".join(itos[idx] for idx in indexes)

In [None]:
def generate(seed_text, n_lines=10, temperature=0.9, max_tokens=100):
    """Sample a continuation of ``seed_text`` from the notebook-level ``model``.

    Args:
        seed_text: space-separated seed tokens; must encode to at least one token.
        n_lines: unused; kept so existing calls keep working.
        temperature: divisor applied to the log-probabilities before sampling
            (lower values make sampling greedier).
        max_tokens: number of tokens to sample (default=100).
    Returns:
        A list of tokens: the whitespace-split seed followed by the sampled tokens.
    Raises:
        ValueError: if ``seed_text`` contains no tokens.
    """
    generated_text = seed_text.split()

    seed_ids = encode_text(seed_text)
    if not seed_ids:
        raise ValueError("seed_text must contain at least one token")

    # The model expects (seq_len, batch); use a batch of one.
    context = torch.tensor(seed_ids, dtype=torch.long, device=device).unsqueeze(1)
    itos = vocab.get_itos()

    # Sample with dropout disabled and without tracking gradients, then put
    # the model back into whatever mode it was in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_tokens):
                # Default causal mask, matching how the model is called during
                # training; keep only the prediction at the last position.
                log_probs = model(context)[-1, 0]
                word_weights = log_probs.div(temperature).exp().cpu()
                word_idx = torch.multinomial(word_weights, 1).item()

                next_token = torch.tensor([[word_idx]], dtype=torch.long, device=device)
                context = torch.cat([context, next_token], dim=0)

                generated_text.append(itos[word_idx])
    finally:
        model.train(was_training)

    return generated_text

In [None]:
def post_process_line(text):
    """Re-attach space-separated punctuation in ``text`` and expand newline markers.

    The replacements run in a fixed order: closing punctuation loses the space
    before it, opening punctuation loses the space after it, apostrophes and
    hyphens lose the spaces on both sides, and finally the lower-cased NEWLINE
    marker becomes a real line break with the space after it removed.
    """
    for mark in ('.', ':', '!', ';', ')', ']', '?', ',', '%'):
        text = text.replace(f" {mark}", mark)

    for mark in ('[', '(', '$'):
        text = text.replace(f"{mark} ", mark)

    for mark in ("'", '-'):
        text = text.replace(f" {mark} ", mark)

    return text.replace(NEWLINE.lower(), "\n").replace("\n ", "\n")

def post_process(text):
    """Render a marker-token stream as a human-readable script.

    Args:
        text: either a single string of tokens separated by single spaces
            (as returned by ``generate_text``) or an iterable of tokens
            (as returned by ``generate``).
    Returns:
        The formatted script. Scene markers become banner lines, the token
        after a speaker-start marker becomes ``"\\nNAME: "``, a line-end marker
        becomes ``"."`` and each complete sentence is passed through
        ``post_process_line``. Tokens outside these cases are ignored. A
        sentence or speaker marker that is cut off by the end of the input is
        dropped.
    """
    tokens = text.split(" ") if isinstance(text, str) else list(text)

    scene_start = SCENE_START.lower()
    scene_end = SCENE_END.lower()
    speaker_start = SPEAKER_START.lower()
    line_end = LINE_END.lower()
    sent_start = SENT_START.lower()
    sent_end = SENT_END.lower()

    pieces = []
    idx = 0
    total = len(tokens)

    while idx < total:
        token = tokens[idx]

        if token == scene_start:
            pieces.append("\n=== SCENE START ===")
        elif token == scene_end:
            pieces.append("\n=== SCENE END ===\n")
        elif token == speaker_start:
            # The speaker name is the token right after the marker.
            idx += 1
            if idx >= total:
                break
            pieces.append(f"\n{tokens[idx]}: ")
        elif token == line_end:
            pieces.append(".")
        elif token == sent_start:
            # Collect everything up to the matching sentence-end marker.
            idx += 1
            words = []
            while idx < total and tokens[idx] != sent_end:
                words.append(tokens[idx])
                idx += 1
            if idx >= total:
                # Input ended mid-sentence: drop the unfinished sentence.
                break
            pieces.append(post_process_line("".join(word + " " for word in words)))

        idx += 1

    return "".join(pieces)

In [None]:
import torch
from torch.nn.functional import softmax

def generate_text(model, tokenizer, seed_text, max_len=100, temperature=1.0, top_k=0):
    """Sample a continuation of ``seed_text`` and return it as one decoded string.

    Args:
        model: the language model; called with a (seq_len, 1) tensor of token
            ids and expected to return per-position log-probabilities.
        tokenizer: unused (``encode_text`` uses the notebook-level tokenizer);
            kept so existing calls keep working.
        seed_text: space-separated seed tokens; must encode to at least one token.
        max_len: maximum number of tokens to sample.
        temperature: divisor applied to the log-probabilities before filtering.
        top_k: if > 0, sample only among the ``top_k`` most likely tokens.
    Returns:
        The seed and sampled tokens joined by spaces. Sampling stops early once
        the scene-end marker is produced.
    Raises:
        ValueError: if ``seed_text`` contains no tokens.
    """
    model.eval()
    device = next(model.parameters()).device

    encoded = encode_text(seed_text)
    if not encoded:
        raise ValueError("seed_text must contain at least one token")

    # The model expects (seq_len, batch); use a batch of one. A plain 1-D
    # tensor would be broadcast to 3-D by the positional encoding and make
    # torch.multinomial fail.
    input_ids = torch.tensor(encoded, dtype=torch.long, device=device).unsqueeze(1)
    itos = vocab.get_itos()

    with torch.no_grad():
        for _ in range(max_len):
            # Default causal mask, matching how the model is called during
            # training; keep only the prediction at the last position.
            log_probs = model(input_ids)[-1, 0]
            # Log-probabilities are valid logits: filter them first and let
            # softmax normalise, instead of exponentiating before the softmax.
            logits = log_probs.div(temperature)
            filtered_logits = top_k_top_p_filtering(logits, top_k=top_k)
            probs = softmax(filtered_logits, dim=-1)
            prev = torch.multinomial(probs, num_samples=1)  # shape (1,)
            input_ids = torch.cat([input_ids, prev.unsqueeze(0)], dim=0)
            if itos[prev.item()] == SCENE_END:
                break

    generated_text = decode_text(input_ids.squeeze(1).tolist())
    return generated_text

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask logits outside the top-k and/or nucleus (top-p) set.

    Works on the last dimension, so ``logits`` may be a single distribution of
    shape ``(vocab,)`` or a batch of shape ``(..., vocab)``. The input tensor
    is not modified.

    Args:
        logits: unnormalised scores.
        top_k: if > 0, keep only the ``top_k`` highest-scoring entries
            (clamped to the vocabulary size).
        top_p: if > 0, keep the smallest set of highest-scoring entries whose
            cumulative probability exceeds ``top_p``; the single most likely
            entry is always kept.
        filter_value: value written into the removed positions.
    Returns:
        A tensor shaped like ``logits`` with removed entries set to ``filter_value``.
    """
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        top_logits, top_indices = torch.topk(logits, k)
        logits = torch.full_like(logits, filter_value).scatter(-1, top_indices, top_logits)
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_remove = cumulative_probs > top_p
        # Shift right so the first entry that crosses the threshold is kept.
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False
        # Map the removal flags from sorted order back to vocabulary order.
        remove = torch.zeros_like(sorted_remove).scatter(-1, sorted_indices, sorted_remove)
        logits = logits.masked_fill(remove, filter_value)
    return logits


In [None]:
seed_speaker = 'Dwight'
seed_text = f"{SCENE_START} {SPEAKER_START} {seed_speaker} {SPEAKER_END} {LINE_START} {SENT_START}".lower()

generated_text = generate(seed_text)
print(post_process(generated_text))

AttributeError: ignored

In [None]:
seed_speaker = 'Dwight'
seed_text = f"{SCENE_START} {SPEAKER_START} {seed_speaker} {SPEAKER_END} {LINE_START} {SENT_START}".lower()
# seed_text = f"{seed_speaker} : "
generated_text = generate_text(model, tokenizer, seed_text, max_len=100, temperature=1.0, top_k=0)
print(post_process(generated_text))

RuntimeError: ignored

In [None]:
# Print the raw generated tokens, a few per row, for inspection.
per_line = 5
# Use a dedicated name here: rebinding `tokens` would overwrite the corpus
# token list that the vocabulary and dataset cells depend on. Accept both a
# token list and a space-separated string.
generated_tokens = generated_text.split() if isinstance(generated_text, str) else generated_text
for i in range(0, len(generated_tokens), per_line):
    print(" ".join(generated_tokens[i:i+per_line]))

<scene_start> <speaker_start> dwight <speaker_end> <line_start>
<sent_start> so to so get
beet me whole <sent_start> <line_end>
<sent_start> hes <sent_end> <sent_start> from
<speaker_start> <speaker_start> <speaker_end> lets <scene_end>
some i <sent_end> <line_end> so
<sent_start> i a <sent_end> <speaker_start>
<speaker_end> <speaker_start> me <line_end> <speaker_start>
<sent_start> <sent_end> <sent_start> <line_start> jim
<sent_end> <sent_start> <line_start> all <sent_end>
darryl <speaker_start> <speaker_start> to <speaker_start>
not about <sent_end> to <line_start>
<sent_end> think right the <line_start>
<line_start> you kevin <sent_end> me
to <line_end> <speaker_start> see pam
<sent_end> <speaker_start> moment is i
feel <sent_end> lately or jim
<speaker_end> did the because dwight
did <sent_start> to <sent_end> a
<sent_end> these the <speaker_start> <sent_start>
all jim <speaker_start> <sent_end> i
<sent_start>
