In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn.functional as F
from torch import optim
import torch.nn as nn
import torch

import numpy as np
import random

# Get DialoGPT

In [None]:
# Load the pretrained DialoGPT-medium tokenizer and causal language model;
# the model is moved to the "cuda" device right after loading.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium").to("cuda")

In [None]:
# Use id 0 as the padding id; masked_loss / masked_accuracy in this notebook
# exclude positions whose label equals 0.
# NOTE(review): id 0 may also be a regular vocabulary token for this tokenizer — confirm.
tokenizer.pad_token_id = 0

In [None]:
# Encode a sample prompt followed by the EOS token and run one forward pass.
text = "Hello there"
tokenized = tokenizer.encode_plus(text + tokenizer.eos_token,
                                  return_tensors="pt")['input_ids']
# The ids are moved to "cuda" for the call; element 0 of the output is kept as `logits`.
logits = model(tokenized.to("cuda"))[0]

In [None]:
# Display the logits computed in the previous cell.
logits

In [None]:
# Greedy pick: argmax over dim 2, then the entry for the first sequence's last position.
torch.argmax(logits, 2)[0][-1]

In [None]:
# Greedy decoding: append the arg-max token `max_len` times.
# `tokenized` keeps its (1, seq_len) shape and stays on its original device;
# only the copy fed to the model is moved to the GPU.
max_len = 10
with torch.no_grad():  # inference only: do not build an autograd graph
    for i in range(max_len):
        # `tokenized` is already a tensor; re-wrapping it with torch.tensor()
        # would copy-construct and emit a UserWarning.
        out = model(tokenized.to("cuda"))
        new = torch.argmax(out[0], 2)[0][-1]
        tokenized = torch.cat(
            (tokenized, new.view(1, 1).to(tokenized.device)), dim=1
        )

In [None]:
# The attribute is referenced without being called, so the cell shows the
# bound method's repr rather than iterating over the parameters.
model.parameters

# Get Data

In [None]:
def readLangs(path="/media/nathanmon/389E28739E282BB6/Users/Natha/Datasets/MyJarvisConversation/conversation.txt"):
    """Read the conversation transcript and return one string per conversation.

    The first line of the file is skipped. Every following line is dispatched
    on its first character:

    * ``G`` - closes the current conversation and switches to "general" mode,
      in which every ``J`` line closes a conversation on its own.
    * ``C`` - closes the current conversation.
    * ``U`` - appended to the current conversation unchanged.
    * ``J`` - ``"/u"`` is rewritten to ``"/u "``, ``"/t "`` is appended, and
      the result is added to the current conversation.

    Lines starting with any other character are ignored. A closed conversation
    is stored lower-cased, with its last character removed and newlines
    replaced by spaces.

    Args:
        path: Transcript file to read. Defaults to the dataset location this
            notebook was written against.

    Returns:
        list[str]: The conversations, followed by one trailing ``""``
        placeholder (callers slice it off with ``[:-1]``).

    NOTE(review): text accumulated after the last ``G``/``C`` marker is never
    closed outside general mode and is therefore not returned - confirm that
    the data file always ends with a marker.
    """
    print("Reading lines...")

    conversations = [""]
    pending = ""
    general = False

    def flush():
        """Store the pending text in the open slot and open a new empty slot."""
        nonlocal pending
        conversations[-1] = pending[:-1].lower().replace("\n", " ")
        conversations.append("")
        pending = ""

    with open(path, "r") as f:
        next(f, None)  # skip the first line of the file
        for line in f:
            marker = line[:1]
            if marker == "G":
                flush()
                general = True
            elif marker == "C":
                flush()
            elif marker == "U":
                pending += line
            elif marker == "J":
                pending += line.replace("/u", "/u ") + "/t "
                if general:
                    flush()

    return conversations

In [None]:
# readLangs() ends with an empty placeholder and also emits "" whenever a
# marker closes a conversation that has no text yet. Empty entries produce no
# training windows, so drop all of them before shuffling.
conversations = [conv for conv in readLangs() if conv.strip()]
np.random.shuffle(conversations)

In [None]:
# Number of conversations loaded.
len(conversations)

In [None]:
# Inspect the token ids produced for text containing the "/t" marker that
# readLangs appends after each "J" line.
tokenizer.encode("sir /t user")

In [None]:
# Show the first conversation after shuffling.
conversations[0]

# Train Model

In [None]:
def masked_loss(label, pred):
    """Cross-entropy over the positions whose label is not 0.

    Args:
        label: Integer tensor of target ids; it is flattened with ``view(-1)``.
        pred: Score tensor whose last dimension is the class dimension; all
            leading dimensions are flattened together.

    Returns:
        Scalar tensor produced by ``nn.CrossEntropyLoss(ignore_index=0)`` on
        the rows that survive the ``label != 0`` filter.
    """
    flat_pred = pred.view(-1, pred.size(-1))
    flat_label = label.view(-1)

    # Rows with label 0 are removed explicitly before the criterion sees them.
    keep = (label != 0).view(-1)

    criterion = nn.CrossEntropyLoss(ignore_index=0)
    return criterion(flat_pred[keep], flat_label[keep])


def masked_accuracy(label, pred):
    """Fraction of positions with a non-zero label where the arg-max matches.

    Args:
        label: Tensor of target ids; positions equal to 0 are excluded.
        pred: Score tensor; the arg-max is taken along axis 2.

    Returns:
        Scalar float32 tensor: matching non-zero positions divided by the
        number of non-zero positions.
    """
    predicted = torch.argmax(pred, axis=2)
    label = label.to(predicted.dtype)

    valid = label != 0
    correct = (label == predicted) & valid

    n_correct = torch.sum(correct.to(torch.float32))
    n_valid = torch.sum(valid.to(torch.float32))
    return n_correct / n_valid

In [None]:
# Replacement pools used by augment(): one list of stripped lines per file.
keywords = ["/u shopping", "/u todolist", "/u wiki", "/u volume", "/a/"]
filenames = ["shopping_items", "todo_list_items", "wiki_queries", "volumes", "apps"]
augments = {name: [] for name in filenames}

for keyword, filename in zip(keywords, filenames):
    with open(f"/media/nathanmon/389E28739E282BB6/Users/Natha/Datasets/MyJarvisConversation/{filename}.txt", "r") as f:
        for line in f:
            # Remove newlines, then surrounding whitespace.
            augments[filename].append(line.replace("\n", "").strip())

In [None]:
# Words kept by sentence2num: number words plus the "mute" and "?" tokens.
numbers = [
    "zero", "one", "two", "three", "four",
    "five", "six", "seven", "eight", "nine",
    "ten", "eleven", "twelve", "thirteen",
    "fourteen", "fifteen", "sixteen",
    "seventeen", "eighteen", "nineteen",
    "twenty", "thirty", "forty", "fifty",
    "sixty", "seventy", "eighty", "ninety",
    "hundred", "thousand", "million", "billion",
    "trillion", "quadrillion", "quintillion", "mute", "?",
]

def sentence2num(sentence):
    """Keep only the words of `sentence` that appear in `numbers`.

    The sentence is split on single spaces and matching is case-insensitive;
    the kept words retain their original casing and order.

    Returns:
        str: The kept words joined by single spaces ("" when none match).
    """
    return " ".join(
        word for word in sentence.split(" ") if word.lower() in numbers
    )

def find_tgt(response, loc="'"):
    """Return the text between the first two occurrences of `loc`.

    Args:
        response: String to search.
        loc: Delimiter marking both ends of the target (default: a single quote).

    Returns:
        The enclosed substring, or None when `response` contains fewer than
        two occurrences of `loc`.
    """
    opening = response.find(loc)
    if opening == -1:
        return None

    start = opening + len(loc)
    end = response.find(loc, start)
    if end == -1:
        return None

    return response[start:end]
    
def augment(inp, tgt):
    """Swap the item mentioned in an (input, target) pair for a random one.

    For every keyword found in `inp` or `tgt`, the current item is located,
    a replacement is drawn from the matching pool in the module-level
    `augments` dict, and the item is replaced in both strings:

    * "/u volume": the item is the number words inside the quoted part of
      `tgt` (see `sentence2num`). "?" and "mute" are never replaced.
    * "/a": the item is the text between the first two "/a" markers of
      `inp`, without its first and last character. In `tgt` the quoted part,
      without its first and last character, is replaced.
    * otherwise: the item is the quoted part of `tgt`.

    A keyword is skipped when no item can be located, when the located item
    is empty, or when its replacement pool is empty, so the strings are left
    untouched instead of raising or being corrupted.

    Args:
        inp: Input text.
        tgt: Target text.

    Returns:
        tuple[str, str]: `(inp, tgt)` after replacement, with every "/a"
        marker removed from both.
    """
    keywords = ["/u shopping", "/u todolist", "/u wiki", "/u volume", "/a"]
    filenames = ["shopping_items", "todo_list_items", "wiki_queries", "volumes", "apps"]

    for keyword, filename in zip(keywords, filenames):
        if keyword not in tgt and keyword not in inp:
            continue

        if keyword == "/u volume":
            quoted = find_tgt(tgt)
            prev_item = sentence2num(quoted) if quoted is not None else None
        elif keyword == "/a":
            marked = find_tgt(inp, "/a")
            prev_item = marked[1:-1] if marked is not None else None
        else:
            prev_item = find_tgt(tgt)

        # None means "not found"; "" must be skipped as well because
        # str.replace("", x) would insert x between every character.
        if not prev_item:
            continue

        # The original compared against "/uvolume" (missing space), so this
        # guard never fired and "mute" / "?" were replaced by volume values.
        if keyword == "/u volume" and (prev_item == "?" or prev_item.lower() == "mute"):
            continue

        pool = augments[filename]
        if not pool:
            continue  # random.choice raises IndexError on an empty pool

        replacement = random.choice(pool)
        inp = inp.replace(prev_item, replacement)

        if keyword == "/a":
            quoted = find_tgt(tgt)
            prev_item = quoted[1:-1] if quoted is not None else None

        if prev_item:
            tgt = tgt.replace(prev_item, replacement)

    return inp.replace("/a", ""), tgt.replace("/a", "")

def _split_with_gaps(line, sep):
    """Cut `line` after every `sep`, remembering the whitespace that followed.

    Returns a list of `(chunk, gap)` pairs. Every chunk except possibly the
    last ends with `sep`; `gap` is the single whitespace character that
    directly followed the separator ("" if there was none). Concatenating
    `chunk + gap` over all pairs reproduces `line` exactly.
    """
    if not sep:
        # An empty separator matches everywhere; treat it as "no separator".
        return [(line, "")]

    pieces = []
    start = 0
    while True:
        hit = line.find(sep, start)
        if hit == -1:
            break
        end = hit + len(sep)
        follower = line[end:end + 1]
        gap = follower if follower.isspace() else ""
        pieces.append((line[start:end], gap))
        start = end + len(gap)

    # Keep whatever follows the last separator (or the whole line when the
    # separator never occurs).
    if start < len(line) or not pieces:
        pieces.append((line[start:], ""))
    return pieces


def split_lines(line, sep):
    """Split `line` into chunks that each end with `sep`.

    A single whitespace character directly after a separator is dropped.
    Text after the last separator is returned as a final chunk; previously it
    was discarded, and two adjacent separators raised ValueError.

    Args:
        line: Text to split.
        sep: Separator string.

    Returns:
        list[str]: The chunks, or `[line]` when `sep` is empty or absent.
    """
    return [chunk for chunk, _gap in _split_with_gaps(line, sep)]


def augment_tokens(tokenizer, tokens):
    """Decode `tokens`, augment each "/t"-terminated chunk, and re-encode.

    A chunk is augmented only when it contains a newline: the part before the
    first newline is passed to `augment` as the input and the rest as the
    target. All other text, including the tail after the last "/t" and the
    whitespace between chunks, is passed through unchanged, so the text handed
    to `tokenizer.encode` equals the decoded text whenever nothing is
    augmented.

    Args:
        tokenizer: Object providing `decode(tokens)` and `encode(text)`.
        tokens: Token ids to augment.

    Returns:
        The token ids returned by `tokenizer.encode` for the augmented text.
    """
    text = tokenizer.decode(tokens)

    rebuilt = []
    for chunk, gap in _split_with_gaps(text, "/t"):
        if "\n" in chunk:
            inp, tgt = augment(*chunk.split("\n", 1))
            chunk = inp + "\n" + tgt
        rebuilt.append(chunk + gap)

    return tokenizer.encode("".join(rebuilt))

In [None]:
def train_epoch(conversations, model, optimizer, criterion,
                window_length, scheduler, print_every, plot_every, train=True):
    """Run one pass over `conversations`, predicting one next token at a time.

    Each conversation is encoded with the module-level `tokenizer` (EOS
    appended). For every prefix end `i`, the last `window_length` tokens
    before `i` are passed through `augment_tokens`; the model is fed all but
    the last token of the result and scored on predicting that last token.

    Args:
        conversations: Sequence of conversation strings.
        model: Model whose output's first element has shape
            (batch, seq_len, vocab).
        optimizer: Optimizer stepped after every window when `train` is True.
        criterion: Loss called as `criterion(logits, target)` with a 1-D
            logits vector and a scalar target id.
        window_length: Maximum number of tokens per window (target included).
        scheduler: Stepped once per conversation with the running mean loss;
            only used when `train` is False.
        print_every: Print progress every this many conversations.
        plot_every: Record a plot point every this many conversations.
        train: True to update weights; False to evaluate without gradients.

    Returns:
        tuple[list[float], list[float]]: Mean per-token loss of every
        `plot_every` window, and the learning rate recorded at the same time.
    """
    # Run on whatever device the model lives on instead of assuming "cuda".
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # from_pretrained() returns models in eval mode, so switch explicitly.
    model.train(train)

    epoch_loss, epoch_tokens = 0.0, 0    # whole pass; feeds the scheduler
    print_loss, print_tokens = 0.0, 0    # reset every print_every
    plot_loss, plot_tokens = 0.0, 0      # reset every plot_every
    plot_losses = []
    plot_learning_rates = []

    for counter, conversation in enumerate(conversations, start=1):
        tokenized = tokenizer.encode(conversation + tokenizer.eos_token)

        # "+ 1" lets the final (EOS) token be used as a target too.
        for i in range(2, len(tokenized) + 1):
            segment = tokenized[max(i - window_length, 0):i]
            augmented = augment_tokens(tokenizer, segment)
            if len(augmented) < 2:
                continue  # need at least one input token and one target
            inp = augmented[:-1]
            tgt = augmented[-1]

            with torch.set_grad_enabled(train):
                logits = model(torch.tensor([inp], device=device))[0][0][-1]
                loss = criterion(logits, torch.tensor(tgt, device=device))

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            step_loss = loss.item()
            epoch_loss += step_loss
            print_loss += step_loss
            plot_loss += step_loss
            epoch_tokens += 1
            print_tokens += 1
            plot_tokens += 1

        ######## metrics ########
        current_lr = optimizer.param_groups[0]["lr"]

        if not train and epoch_tokens:
            scheduler.step(epoch_loss / epoch_tokens)

        if counter % print_every == 0 and print_tokens:
            print('Conversation  %d: %d%% %.4f %.7f' % (counter,
                  counter / len(conversations) * 100,
                  print_loss / print_tokens, current_lr))
            print_loss, print_tokens = 0.0, 0

        # Both lists grow together so they can be plotted against each other.
        if counter % plot_every == 0 and plot_tokens:
            plot_losses.append(plot_loss / plot_tokens)
            plot_learning_rates.append(current_lr)
            plot_loss, plot_tokens = 0.0, 0

    return plot_losses, plot_learning_rates

In [None]:
import time
import math

def asMinutes(s):
    """Format a duration given in seconds as '<minutes>m <seconds>s'."""
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)

def timeSince(since, percent):
    """Report elapsed time and the projected time remaining.

    Args:
        since: Start timestamp as returned by time.time().
        percent: Fraction of the work completed so far (must be non-zero).

    Returns:
        str: '<elapsed> (- <remaining>)', both formatted by asMinutes.
    """
    elapsed = time.time() - since
    projected_total = elapsed / (percent)
    remaining = projected_total - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))

In [None]:
class SaveBestModel:
    """
    Save a checkpoint whenever the validation loss improves.

    If the current validation loss is lower than the lowest loss seen so far,
    the model and optimizer state are written to `save_path`.
    """
    def __init__(
        self, best_valid_loss=float('inf'),
        save_path='checkpoints/best_causal_model.pth'
    ):
        # Lowest validation loss seen so far; a checkpoint is written only
        # when a call reports a strictly lower value.
        self.best_valid_loss = best_valid_loss
        # File the checkpoint is written to.
        self.save_path = save_path

    def __call__(
        self, current_valid_loss,
        epoch, model, optimizer, criterion
    ):
        """Write a checkpoint if `current_valid_loss` beats the best so far.

        Args:
            current_valid_loss: Validation loss of the epoch just finished.
            epoch: Zero-based epoch index (reported and stored as `epoch + 1`).
            model: Module whose `state_dict()` is saved.
            optimizer: Optimizer whose `state_dict()` is saved.
            criterion: Stored under the 'loss' key, as before.
        """
        import os  # local import: `os` is not imported at the top of this notebook

        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"Best validation loss: {self.best_valid_loss}")
            print(f"Saving best model for epoch: {epoch+1}")

            # torch.save does not create missing directories.
            directory = os.path.dirname(self.save_path)
            if directory:
                os.makedirs(directory, exist_ok=True)

            torch.save({
                'epoch': epoch+1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': criterion,
                # Numeric loss that triggered this checkpoint.
                'valid_loss': current_valid_loss,
                }, self.save_path)

In [None]:
def train(train_conversations, val_conversations, model, n_epochs,
          window_length=25, learning_rate=1e-3, print_every=100, plot_every=100,
          best_valid_loss=.59):
    """Fine-tune `model` for `n_epochs`, validating after every epoch.

    Each epoch runs `train_epoch` once on the training conversations and once
    (with `train=False`) on the validation conversations. The epoch's loss
    figures are the means of the plot points returned by `train_epoch`; the
    validation figure decides whether a checkpoint is written.

    Args:
        train_conversations: Conversations used for weight updates.
        val_conversations: Conversations used for validation.
        model: Model to fine-tune.
        n_epochs: Number of epochs.
        window_length: Context window, in tokens, passed to `train_epoch`.
        learning_rate: Initial Adam learning rate.
        print_every: Forwarded to `train_epoch`.
        plot_every: Forwarded to `train_epoch`.
        best_valid_loss: A checkpoint is written only for validation losses
            below this threshold (and below every earlier epoch's loss).

    Returns:
        numpy.ndarray: Training-loss plot points of all epochs, concatenated.
    """
    start = time.time()
    plot_train_losses = []
    plot_val_losses = []

    plot_learning_rates = []

    save_best = SaveBestModel(best_valid_loss=best_valid_loss)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate,
                           betas=(0.95, 0.9995), eps=1e-9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", factor=0.05, patience=300)
    # The model outputs raw logits. NLLLoss expects log-probabilities, so
    # CrossEntropyLoss (log-softmax + NLL) is the matching criterion.
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1, n_epochs + 1):
        plot_train_loss, _ = train_epoch(train_conversations, model, optimizer, criterion,
                                         window_length, scheduler, print_every, plot_every)

        plot_train_losses = np.concatenate((plot_train_losses, plot_train_loss))

        # Evaluate on the validation conversations
        plot_val_loss, plot_learning_rate = train_epoch(val_conversations, model, optimizer, criterion,
                                                        window_length, scheduler, print_every, plot_every,
                                                        train=False)
        plot_val_losses = np.concatenate((plot_val_losses, plot_val_loss))
        plot_learning_rates = np.concatenate((plot_learning_rates, plot_learning_rate))

        # Summarise the epoch; an epoch without plot points has no usable
        # loss and must never win the "best model" comparison.
        train_loss = float(np.mean(plot_train_loss)) if len(plot_train_loss) else float("nan")
        val_loss = float(np.mean(plot_val_loss)) if len(plot_val_loss) else float("inf")

        print('%s (%d %d%%) %.4f %.4f' % (timeSince(start, epoch / n_epochs),
                epoch, epoch / n_epochs * 100, train_loss, val_loss))

        # SaveBestModel expects a zero-based epoch (it reports epoch + 1).
        save_best(val_loss, epoch - 1, model, optimizer, criterion)

    showPlot(plot_train_losses, "loss", plot_val_losses, "val_loss")
    showPlot(plot_learning_rates, "learning rate")
    return plot_train_losses

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np

def showPlot(points, points_name, points2=None, points2_name=None):
    """Plot one or two series against their index.

    Args:
        points: First series (list or array).
        points_name: Legend entry for `points`.
        points2: Optional second series drawn on the same axes.
        points2_name: Legend entry for `points2`.
    """
    # plt.subplots() already creates the figure; an extra plt.figure() call
    # would leave an empty figure behind.
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.5)
    ax.yaxis.set_major_locator(loc)

    ax.plot(np.arange(len(points)), points)
    # `is not None` rather than `!= None`: comparing a NumPy array with
    # `!=` is element-wise and cannot be used as an `if` condition.
    if points2 is not None:
        ax.plot(np.arange(len(points2)), points2)
        ax.legend([points_name, points2_name])
    else:
        ax.legend([points_name])

In [None]:
# Total number of conversations available for the train/validation split.
len(conversations)

In [None]:
# Index at which the list is cut: the first 90% goes to training.
int(.9 * len(conversations))

In [None]:
# Release cached, currently unused GPU memory before starting training.
torch.cuda.empty_cache()

In [None]:
batch_size = 32

# 90% of the conversations are used for training, the remainder for validation.
split_at = int(.9 * len(conversations))
train_conversations, val_conversations = conversations[:split_at], conversations[split_at:]

history = train(
    train_conversations,
    val_conversations,
    model,
    n_epochs=2,
    window_length=65,
    learning_rate=1e-5,
    print_every=5,
    plot_every=5,
)

In [None]:
# Run the fine-tuned model on a sample prompt.
text = "Hello there"
tokenized = tokenizer.encode_plus(text + tokenizer.eos_token,
                                  return_tensors="pt")['input_ids']
# The tokenizer returns CPU tensors while the model lives on the GPU, so the
# ids are moved to the model's device; no_grad avoids building an autograd
# graph for a pure inference call.
with torch.no_grad():
    logits = model(tokenized.to(model.device))[0]

In [None]:
# Greedy decoding with the fine-tuned model: append the arg-max token
# `max_len` times. `tokenized` keeps its (1, seq_len) shape and device.
max_len = 10
with torch.no_grad():  # inference only
    for i in range(max_len):
        # Inputs must be on the model's device; `tokenized` is already a
        # tensor, so it is moved rather than re-wrapped with torch.tensor().
        out = model(tokenized.to(model.device))
        new = torch.argmax(out[0], 2)[0][-1]
        tokenized = torch.cat(
            (tokenized, new.view(1, 1).to(tokenized.device)), dim=1
        )