In [1]:
import pandas as pd
import torch
import torch.nn as nn
from transformers import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel
import torch
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer, RobertaModel
import math
import Model_Import_6
from torch import optim
import random

In [2]:
# Use the first GPU when CUDA is visible; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [3]:
# Fetch the NLTK "punkt" tokenizer and "averaged_perceptron_tagger" models into
# the local nltk_data directory; already-present packages are reported as
# up-to-date rather than re-downloaded (see the recorded output below).
# NOTE(review): no visible cell in this notebook calls nltk — confirm this cell
# is still needed.
import nltk
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

[nltk_data] Downloading package punkt to /home/ubuntu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/ubuntu/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

In [4]:
def load_stuff(context_tensors, generator_logits, text_ids_loc, device=None):
    """Load the three artifacts saved with ``torch.save`` for one dataset.

    Args:
        context_tensors: path to the saved context embeddings.
        generator_logits: path to the saved generator logits.
        text_ids_loc: path to the saved tokenized text ids.
        device: where every loaded storage is mapped. Defaults to ``"cuda:0"``
            when CUDA is available and ``"cpu"`` otherwise, so the data can
            also be loaded on a CPU-only machine.

    Returns:
        ``(context, logits, text_ids)``: the objects as they were saved, with
        their tensors placed on ``device``.
    """
    if device is None:
        # The previous ``lambda storage, loc: storage.cuda(0)`` raised on
        # machines without a GPU even though the notebook's ``device`` falls
        # back to CPU.
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # NOTE(review): torch>=2.6 defaults ``torch.load`` to ``weights_only=True``;
    # if these files contain pickled non-tensor objects (e.g. tokenizer
    # outputs), loading may need ``weights_only=False`` — confirm for the
    # installed torch version. Only do that for trusted files.
    context, logits, text_ids = (
        torch.load(path, map_location=device)
        for path in (context_tensors, generator_logits, text_ids_loc)
    )
    return context, logits, text_ids

In [5]:
# Load the "neg" artifacts: context embeddings, generator logits and the
# tokenized target ids used by the training cells below.
R_neg_embeds, neg_logits, neg_token_ids = load_stuff(
    context_tensors='R_neg_embeds.pt',
    generator_logits='neg_logits.pt',
    text_ids_loc='neg_token_ids.pt',
)

In [6]:
def train_one_style(model, optimizer, context_embeds_list, logits_list, token_ids_list, epochs, num_samples = 100, device=None):
    """Train ``model`` on next-token cross-entropy from logits plus sampled context.

    Each epoch shuffles the (logits, token ids) pairs together. For every pair it
    draws ``num_samples`` random context embeddings with ``random.sample``,
    stacks them, and takes one optimizer step on the shifted next-token loss.

    Args:
        model: module called as ``model(stacked_context, logits)``; the last
            dimension of its output is treated as the class dimension.
        optimizer: optimizer over ``model``'s parameters.
        context_embeds_list: list of context embedding tensors to sample from.
        logits_list: per-example logits fed to the model.
        token_ids_list: per-example mappings whose ``'input_ids'`` entry holds
            the target ids; ``input_ids.shape[1]`` is read as the sequence
            length (assumes shape ``(1, seq)`` — TODO confirm).
        epochs: number of passes over the data.
        num_samples: number of context embeddings sampled per example.
        device: device the inputs are moved to. Defaults to the device of the
            model's first parameter, or CPU when the model has no parameters.

    Returns:
        list[float]: mean training loss of each epoch, averaged over the
        examples actually trained on (``nan`` for an epoch with none).

    Raises:
        ValueError: if ``logits_list`` and ``token_ids_list`` differ in length.
    """
    if len(logits_list) != len(token_ids_list):
        raise ValueError(
            f"logits_list has {len(logits_list)} items but token_ids_list has {len(token_ids_list)}"
        )
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    ce_loss = nn.CrossEntropyLoss()
    # Keep logits and targets paired while shuffling; reshuffled in place every epoch.
    pairs = list(zip(logits_list, token_ids_list))
    history = []
    for epoch in range(epochs):
        random.shuffle(pairs)
        model.train()
        loss_sum = 0.0
        epoch_count = 0
        for logits, token_ids in pairs:
            input_ids = token_ids['input_ids']
            # Fewer than two ids leaves no (input, next-token) pair after the
            # shift below (and an empty target yields a NaN loss). Check before
            # sampling and the forward pass so skipped examples cost nothing.
            if input_ids.shape[1] < 2:
                print(f"Skipping example with {input_ids.shape[1]} text id(s)")
                continue
            context_sample = torch.stack(random.sample(context_embeds_list, num_samples), dim=0)
            optimizer.zero_grad()
            network_output = model(context_sample.to(device), logits.to(device))
            # Output at position t is scored against the token at position t + 1.
            shifted_output = network_output[..., :-1, :].contiguous()
            shifted_ids = input_ids[..., 1:].contiguous().to(device)
            loss = ce_loss(shifted_output.view(-1, shifted_output.size(-1)), shifted_ids.view(-1))
            loss.backward()
            optimizer.step()
            # .item() detaches the value; summing the tensor itself would keep
            # every step's autograd graph alive for the whole epoch.
            loss_sum += loss.item()
            epoch_count += 1
        # Average over trained examples only, not over the skipped ones.
        epoch_loss = loss_sum / epoch_count if epoch_count else float("nan")
        history.append(epoch_loss)
        print(f"Epoch: {epoch}, Epoch Examples: {epoch_count}")
        print(f"TRAIN LOSS: {epoch_loss}")
        print("----------------------------------------")
    return history

In [9]:
# Constructor sizes come from the data: dim 0 of the first context embedding
# and dim 1 of the first logits entry.
neg_alone_model = Model_Import_6.Test_skip_norm_model(
    R_neg_embeds[0].shape[0],
    neg_logits[0].shape[1],
    attention_dim=None,
).to(device)
neg_optimizer = optim.Adam(
    neg_alone_model.parameters(),
    lr=1e-5,
    weight_decay=1e-3,
)

In [10]:
# 20 epochs on the "neg" data (the recorded run was interrupted manually).
train_one_style(
    model=neg_alone_model,
    optimizer=neg_optimizer,
    context_embeds_list=R_neg_embeds,
    logits_list=neg_logits,
    token_ids_list=neg_token_ids,
    epochs=20,
)

KeyboardInterrupt: 

In [None]:
# Same architecture as the previous attempt, retried with a 10x larger learning rate.
neg_alone_model = Model_Import_6.Test_skip_norm_model(
    R_neg_embeds[0].shape[0],
    neg_logits[0].shape[1],
    attention_dim=None,
).to(device)
neg_optimizer = optim.Adam(
    neg_alone_model.parameters(),
    lr=1e-4,
    weight_decay=1e-3,
)

In [None]:
# 20 epochs on the "neg" data with the higher-learning-rate optimizer.
train_one_style(
    model=neg_alone_model,
    optimizer=neg_optimizer,
    context_embeds_list=R_neg_embeds,
    logits_list=neg_logits,
    token_ids_list=neg_token_ids,
    epochs=20,
)

In [9]:
# DeeperModel_skip variant, sized from the first context embedding / logits entry.
neg_alone_model_deeper = Model_Import_6.DeeperModel_skip(
    R_neg_embeds[0].shape[0],
    neg_logits[0].shape[1],
    attention_dim=None,
).to(device)
neg_optimizer = optim.Adam(
    neg_alone_model_deeper.parameters(),
    lr=1e-5,
    weight_decay=1e-3,
)

In [10]:
# 20 epochs for the deeper variant (the recorded run was interrupted after epoch 9).
train_one_style(
    model=neg_alone_model_deeper,
    optimizer=neg_optimizer,
    context_embeds_list=R_neg_embeds,
    logits_list=neg_logits,
    token_ids_list=neg_token_ids,
    epochs=20,
)

Epoch: 0, Epoch Examples: 12500
TRAIN LOSS: 3.3413705825805664
----------------------------------------
Epoch: 1, Epoch Examples: 12500
TRAIN LOSS: 3.3046321868896484
----------------------------------------
Epoch: 2, Epoch Examples: 12500
TRAIN LOSS: 3.3020107746124268
----------------------------------------
Epoch: 3, Epoch Examples: 12500
TRAIN LOSS: 3.3004133701324463
----------------------------------------
Epoch: 4, Epoch Examples: 12500
TRAIN LOSS: 3.299299955368042
----------------------------------------
Epoch: 5, Epoch Examples: 12500
TRAIN LOSS: 3.2984564304351807
----------------------------------------
Epoch: 6, Epoch Examples: 12500
TRAIN LOSS: 3.2976303100585938
----------------------------------------
Epoch: 7, Epoch Examples: 12500
TRAIN LOSS: 3.297029972076416
----------------------------------------
Epoch: 8, Epoch Examples: 12500
TRAIN LOSS: 3.2964439392089844
----------------------------------------
Epoch: 9, Epoch Examples: 12500
TRAIN LOSS: 3.295955181121826
----

KeyboardInterrupt: 

In [7]:
# WiderModel_skip variant. Author's timing notes: 2 wide -> 1:46, 4 wide -> 3:00
# (presumably wall time per epoch — confirm).
neg_alone_model_wide = Model_Import_6.WiderModel_skip(
    R_neg_embeds[0].shape[0],
    neg_logits[0].shape[1],
    attention_dim=None,
).to(device)
neg_optimizer = optim.Adam(
    neg_alone_model_wide.parameters(),
    lr=1e-5,
    weight_decay=1e-3,
)

In [8]:
# 20 epochs for the wider variant (the recorded run was interrupted after epoch 6).
train_one_style(
    model=neg_alone_model_wide,
    optimizer=neg_optimizer,
    context_embeds_list=R_neg_embeds,
    logits_list=neg_logits,
    token_ids_list=neg_token_ids,
    epochs=20,
)

Epoch: 0, Epoch Examples: 12500
TRAIN LOSS: 3.967576503753662
----------------------------------------
Epoch: 1, Epoch Examples: 12500
TRAIN LOSS: 3.426697015762329
----------------------------------------
Epoch: 2, Epoch Examples: 12500
TRAIN LOSS: 3.3015451431274414
----------------------------------------
Epoch: 3, Epoch Examples: 12500
TRAIN LOSS: 3.2345831394195557
----------------------------------------
Epoch: 4, Epoch Examples: 12500
TRAIN LOSS: 3.191641092300415
----------------------------------------
Epoch: 5, Epoch Examples: 12500
TRAIN LOSS: 3.1612212657928467
----------------------------------------
Epoch: 6, Epoch Examples: 12500
TRAIN LOSS: 3.1387112140655518
----------------------------------------


KeyboardInterrupt: 

In [7]:
# WiderDeeperModel_skip variant, sized from the first context embedding / logits entry.
neg_alone_model_wide_deep = Model_Import_6.WiderDeeperModel_skip(
    R_neg_embeds[0].shape[0],
    neg_logits[0].shape[1],
    attention_dim=None,
).to(device)
neg_optimizer = optim.Adam(
    neg_alone_model_wide_deep.parameters(),
    lr=1e-5,
    weight_decay=1e-3,
)

In [8]:
# 20 epochs for the wider+deeper variant (the recorded run was interrupted after epoch 1).
train_one_style(
    model=neg_alone_model_wide_deep,
    optimizer=neg_optimizer,
    context_embeds_list=R_neg_embeds,
    logits_list=neg_logits,
    token_ids_list=neg_token_ids,
    epochs=20,
)

Epoch: 0, Epoch Examples: 12500
TRAIN LOSS: 4.907911777496338
----------------------------------------
Epoch: 1, Epoch Examples: 12500
TRAIN LOSS: 3.9615514278411865
----------------------------------------


KeyboardInterrupt: 