# Dependencies

In [None]:
from transformers import  GPT2LMHeadModel, GPT2Tokenizer,AdamW
import pandas as pd
from torch.utils.data import Dataset , DataLoader
import pytorch_lightning as pl
from sklearn.model_selection import train_test_split

# The Dataset

In [None]:
# Load the training CSV; its `post_name` column holds the titles used for fine-tuning below.
df = pd.read_csv("../input/mediumsearchdataset/Train.csv")
df  # last expression of the cell: Jupyter renders the DataFrame as the cell output

# Downloading and testing GPT2

In [None]:
# Fetch the pretrained "gpt2-large" checkpoint and its matching BPE tokenizer
# (downloaded from the Hugging Face Hub unless already cached locally).
gpt2 = GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path="gpt2-large")
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path="gpt2-large")

In [None]:
# GPT-2's tokenizer has no pad token; reuse EOS so `padding="max_length"` works.
# Consequence: padding positions are indistinguishable from a real EOS by token id.
tokenizer.pad_token = tokenizer.eos_token

### Before fine-tuning (baseline sample from the pretrained model)

In [None]:
# Baseline: sample from the pretrained model before any fine-tuning.
# The prompt is deliberately NOT padded. Padding it to max_length=30 (as this
# cell used to) right-padded "machine learning" with EOS tokens up to 30
# positions, so generation continued from a run of EOS, not from the prompt.
encoded = tokenizer("machine learning", return_tensors="pt")
prompt = encoded["input_ids"]
output = gpt2.generate(
    prompt,
    attention_mask=encoded["attention_mask"],
    # Pad and EOS are the same token here (tokenizer.pad_token = eos_token).
    pad_token_id=tokenizer.eos_token_id,
    do_sample=True,
    max_length=100,
    top_k=10,
    temperature=0.8,
)
tokenizer.decode(output[0], skip_special_tokens=True)

# Dataset class

In [None]:
class TitleDataset(Dataset):
    """Map-style dataset turning raw titles into fixed-length token-id tensors.

    Each item is a 1-D tensor of exactly ``max_length`` token ids: the title is
    truncated to ``max_length`` tokens and right-padded with the tokenizer's
    pad token.

    Args:
        titles: Indexable sequence of title strings (e.g. ``df["post_name"].values``).
        text_tokenizer: Tokenizer used for encoding. Defaults to the notebook's
            global ``tokenizer`` so existing ``TitleDataset(titles)`` calls keep working.
        max_length: Tokens per item; defaults to 30, the previous hard-coded value.
    """

    def __init__(self, titles, text_tokenizer=None, max_length=30):
        # The tokenizer used to be stored here while __getitem__ silently read
        # the global instead; now the stored instance is the one actually used.
        self.tokenizer = tokenizer if text_tokenizer is None else text_tokenizer
        self.titles = titles
        self.max_length = max_length

    def __len__(self):
        return len(self.titles)

    def __getitem__(self, index):
        title = self.titles[index]
        # encode(..., return_tensors="pt") returns a (1, max_length) tensor;
        # flatten it so a DataLoader stacks items into (batch, max_length).
        return self.tokenizer.encode(
            title,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).reshape(-1)

##### Sanity check (make sure we are feeding the right input to the model)

In [None]:
# Push one randomly chosen title through a DataLoader and show the token-id
# tensor exactly as the model will receive it.
dset = TitleDataset(df["post_name"].values)
title = next(
    iter(DataLoader(dataset=dset, shuffle=True, batch_size=1))
)
display(title)

In [None]:
# Hold out 30% of the rows; the fixed seed keeps the split reproducible.
x_train, x_test = train_test_split(
    df,
    random_state=42,
    test_size=0.3,
)

# Lightning DataModule

In [None]:
class TitleDataModule(pl.LightningDataModule):
    """Serve tokenized titles from the notebook's train/test split to Lightning.

    Args:
        batch_size: Batch size for all three loaders. Defaults to 1 (the
            previous hard-coded value); larger values batch cleanly because
            ``TitleDataset`` pads every item to the same length.
        train_titles: Titles to train on. Defaults to ``x_train["post_name"].values``.
        test_titles: Held-out titles. Defaults to ``x_test["post_name"].values``.
    """

    def __init__(self, batch_size=1, train_titles=None, test_titles=None):
        super().__init__()
        if train_titles is None:
            train_titles = x_train["post_name"].values
        if test_titles is None:
            test_titles = x_test["post_name"].values
        self.batch_size = batch_size
        self.train = TitleDataset(train_titles)
        self.test = TitleDataset(test_titles)
        # NOTE(review): validation reuses the held-out test titles, so val and
        # test losses are not independent. This matters once anything (early
        # stopping, checkpoint selection) is driven by validation loss.
        self.val = TitleDataset(test_titles)

    def train_dataloader(self):
        # Only training data is shuffled; evaluation order stays deterministic.
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size, shuffle=False)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size, shuffle=False)

In [None]:
# TitleGenerator wraps the global `gpt2_model`. This is an alias, not a copy:
# fine-tuning through the wrapper updates the weights of `gpt2` in place.
gpt2_model = gpt2
print("done")

# Lightning Model

In [None]:
class TitleGenerator(pl.LightningModule):
    """Fine-tune a GPT-2 LM head model on tokenized titles (causal-LM objective).

    Each batch is a ``(batch, seq_len)`` tensor of token ids from ``TitleDataset``:
    the title followed by EOS padding. The batch is fed as both input and labels
    (GPT2LMHeadModel shifts the labels internally), and the resulting loss is
    logged as ``train_loss`` / ``val_loss`` / ``test_loss``.

    Args:
        model: GPT-2 LM head model to fine-tune. Defaults to the notebook's
            global ``gpt2_model``. It is wrapped, not copied, so training
            updates that object in place.
        lr: AdamW learning rate; defaults to 1e-4, the previous hard-coded value.
        ignore_padding_in_loss: When True (default), every EOS that directly
            follows another EOS is labelled -100 and excluded from the loss.
            The first EOS after a title stays as the "title ends here" target.
            Without this, most targets in a 30-token row are trivially
            predictable padding that dilutes the loss.
    """

    def __init__(self, model=None, lr=1e-4, ignore_padding_in_loss=True):
        super().__init__()
        self.neural_net = gpt2_model if model is None else model
        self.neural_net.train()
        self.lr = lr
        self.ignore_padding_in_loss = ignore_padding_in_loss

    def _labels(self, x):
        """Build LM labels for ``x``, optionally masking repeated EOS padding with -100."""
        if not self.ignore_padding_in_loss:
            return x
        import torch

        is_eos = x == self.neural_net.config.eos_token_id
        # An EOS whose predecessor is also EOS is padding; -100 is the index
        # the model's cross-entropy loss ignores.
        prev_is_eos = torch.zeros_like(is_eos)
        prev_is_eos[:, 1:] = is_eos[:, :-1]
        return x.masked_fill(is_eos & prev_is_eos, -100)

    def forward(self, x):
        return self.neural_net(x, labels=self._labels(x))

    def configure_optimizers(self):
        from torch.optim import AdamW as TorchAdamW

        # transformers.AdamW is deprecated in favour of torch.optim.AdamW.
        # eps=1e-6 and weight_decay=0.0 reproduce transformers.AdamW's defaults
        # (torch's own defaults are eps=1e-8, weight_decay=0.01).
        return TorchAdamW(self.parameters(), lr=self.lr, eps=1e-6, weight_decay=0.0)

    def _shared_step(self, batch, stage):
        """Compute the LM loss for ``batch`` and log it as ``f"{stage}_loss"``."""
        loss = self(batch).loss
        self.log(f"{stage}_loss", loss, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

# Training

In [None]:
from pytorch_lightning import Trainer

# Build the data pipeline, then wrap the pretrained GPT-2 for training.
module = TitleDataModule()
model = TitleGenerator()
# NOTE: `gpus=` is the pre-2.0 Lightning argument; Lightning >= 2.0 expects
# `accelerator="gpu", devices=1` instead.
trainer = Trainer(gpus=1, max_epochs=8)
trainer.fit(model, datamodule=module)

# Prediction: generating titles with the fine-tuned model

In [None]:
# `model.neural_net` is the very object `gpt2` refers to (TitleGenerator wraps the
# alias `gpt2_model = gpt2` without copying). Training therefore already updated
# these weights in place, and nothing needs to be loaded.
#
# The former line `gpt2.state_dict = model.state_dict` did not load anything
# either. It replaced gpt2's `state_dict` *method* with the wrapper's bound method,
# so later gpt2.state_dict() calls would return keys prefixed with "neural_net.".
#
# Rebind explicitly and switch to eval mode: TitleGenerator puts the model in
# train mode, which would keep dropout active while sampling.
gpt2 = model.neural_net.eval()

In [None]:
# Prompts to sample fine-tuned titles from; repeated prompts ("A", "The") give
# extra, different samples because sampling is stochastic.
raw_text = ["The", "machine Learning", "A", "Data science", "AI", "A", "The", "Why", "how"]
output_text = []
# Sample without dropout: TitleGenerator leaves the shared GPT-2 in train mode.
gpt2.eval()
for x in raw_text:
    # return_tensors="pt" builds CPU tensors; move them to wherever the model's
    # weights live after (GPU) training so generate() does not hit a device mismatch.
    encoded = tokenizer(x, return_tensors="pt").to(gpt2.device)
    prompt = encoded["input_ids"]
    output = gpt2.generate(
        prompt,
        attention_mask=encoded["attention_mask"],
        # Pad and EOS are the same token for this tokenizer.
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        max_length=100,
        top_k=10,
        temperature=0.8,
    )
    output_text.append(tokenizer.decode(output[0], skip_special_tokens=True))

In [None]:
# Render the list of generated titles (IPython's `display`).
display(output_text)