In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd

In [None]:
# Load the scraped subreddit posts and pick one fixed post to use as the prompt.
thread_name = 'AmItheAsshole'
posts_num = 27920

df_short = pd.read_csv(f"{thread_name}_latest_{posts_num}_posts_short.csv")
# The `date` column is parsed from day-month-year strings into datetimes.
df_short["date"] = pd.to_datetime(df_short["date"], format="%d-%m-%Y")

# Draw a single post with a fixed seed (same post on every run) and replace
# its newlines with spaces so the prompt is one line of text.
prompt = (
    df_short["text"]
    .sample(n=1, random_state=102)
    .str.replace("\n", " ")
    .iloc[0]
)

## COMMENT GENERATOR PyTorch

In [None]:
"""
An abstract class representing a Dataset. All datasets that represent a map from keys to data samples should subclass it. 
All subclasses should overwrite __getitem__(), supporting fetching a data sample for a given key. Subclasses could also 
optionally overwrite __len__(), which is expected to return the size of the dataset by many Sampler implementations and 
the default options of DataLoader.
"""

# Define a custom dataset class
class PostsDataset(Dataset):
    """Map-style dataset that tokenizes a post and its first comment per row.

    Every item tokenizes the row's ``"text"`` and ``"first_comment"`` columns
    independently, each padded/truncated to ``max_len`` tokens.

    NOTE(review): ``labels`` are the comment's token ids, placed position by
    position next to the post's token ids in ``input_ids``; the comment is not a
    continuation of the same sequence. Causal language models are normally
    fine-tuned on one sequence (post followed by comment) with labels derived
    from that same sequence -- confirm this layout is intended.

    Args:
        df: DataFrame providing the ``"text"`` and ``"first_comment"`` columns.
        tokenizer: Object exposing a Hugging Face style ``encode_plus`` method.
        max_len: Fixed token length used for both padding and truncation.
    """

    def __init__(self, df, tokenizer, max_len):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of rows (samples) in the underlying DataFrame."""
        return len(self.df)

    def _encode(self, text):
        """Tokenize ``text`` to exactly ``max_len`` tokens, returned as PyTorch tensors."""
        return self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
            truncation=True,
        )

    def __getitem__(self, index):
        """Return the tokenized post and comment for the row at position ``index``.

        Returns:
            dict with 1-D tensors ``input_ids`` and ``attention_mask`` (from the
            post) and ``labels`` (the comment's token ids).
        """
        # Fetch the row once instead of once per column.
        row = self.df.iloc[index]
        # str() mirrors the original behaviour: non-string cells (e.g. NaN) are
        # tokenized as their string representation.
        post_encoded = self._encode(str(row["text"]))
        comment_encoded = self._encode(str(row["first_comment"]))

        # squeeze(0) drops only the leading batch axis added by return_tensors="pt";
        # a bare squeeze() would also drop the token axis when max_len == 1.
        return {
            "input_ids": post_encoded["input_ids"].squeeze(0),
            "attention_mask": post_encoded["attention_mask"].squeeze(0),
            "labels": comment_encoded["input_ids"].squeeze(0),
        }

In [None]:
model_name = "gpt2"

# Load the pretrained tokenizer. `vocab_size` is not a loading option of the
# tokenizer (its vocabulary is fixed by the checkpoint), so it is not passed.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# add a padding token to the tokenizer (the end-of-sequence token is reused)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(model_name)

# define the dataset and the data loader
# NOTE(review): no shuffle=True, so every epoch visits the rows in DataFrame order.
dataset = PostsDataset(df_short, tokenizer, max_len=100)
data_loader = DataLoader(dataset, batch_size=10)

# set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# define the optimizer and the loss function
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
# Label positions equal to the pad id are excluded from the loss. Because the
# pad token is the EOS token, EOS targets are excluded as well.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

# train the model
for epoch in range(0, 10):
    # reset loss value so the printed number is this epoch's mean batch loss
    running_loss = 0

    model.train()

    for batch in data_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()

        # Only the logits are used below, so `labels` is not passed to the model;
        # passing it would make the model compute a second loss that is discarded.
        outs = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)

        # Flatten to (batch * seq_len, vocab) using the model's own output width
        # rather than the tokenizer's vocabulary size, which may differ.
        # NOTE(review): logits at position i are scored against labels at position i
        # (no next-token shift) -- confirm this matches the dataset's label layout.
        logits = outs.logits
        loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1} Loss: {running_loss/len(data_loader)}")

In [None]:
# keep training the model for epochs 11-20, reusing the model, optimizer,
# loss function and data loader created in the previous cell
for epoch in range(10, 20):
    # Reset the accumulator every epoch. Without this the value left over from
    # the previous cell keeps growing and the printed average is inflated.
    running_loss = 0

    model.train()

    for batch in data_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()

        # Only the logits are used below, so `labels` is not passed to the model;
        # passing it would make the model compute a second loss that is discarded.
        outs = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)

        # Flatten to (batch * seq_len, vocab) using the model's own output width.
        logits = outs.logits
        loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1} Loss: {running_loss/len(data_loader)}")

In [None]:
# generate a comment for the sampled post
length = 100  # maximum number of NEW tokens to generate (tokens, not words)

model.eval()

with torch.no_grad():
    # Keep the prompt short enough that prompt + generated tokens fit in the
    # model's context window; longer prompts are truncated on encoding.
    max_positions = getattr(model.config, "max_position_embeddings", None)
    encode_kwargs = {"return_tensors": "pt"}
    if max_positions is not None and max_positions > length:
        encode_kwargs.update(truncation=True, max_length=max_positions - length)

    # encoding
    input_ids = tokenizer.encode(prompt, **encode_kwargs).to(device)
    # number of prompt tokens; generated tokens start at this offset
    input_length = input_ids.shape[1]
    # single unpadded sequence, so every position is attended to
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device)

    # GENERATE THE COMMENT: beam search without repeated bigrams
    comment = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=length,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=False,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Full decoded text (prompt followed by the generated continuation).
    generated_comment_torch = tokenizer.decode(comment[0], skip_special_tokens=True)

    # Decode only the tokens after the prompt. Removing the prompt with
    # str.replace fails whenever the decoded text does not reproduce the
    # prompt character for character.
    new_tokens = comment[0][input_length:]
    print(f"Generated comment: {tokenizer.decode(new_tokens, skip_special_tokens=True)}")