In [None]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

In [None]:
# Pretrained Korean BERT checkpoint pulled from the Hugging Face hub.
# The same identifier is used for both the tokenizer and the encoder weights.
MODEL_NAME = "snunlp/KR-Medium"

# do_lower_case=False is forwarded to the tokenizer so input text is not lowercased.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)
bert_model = AutoModel.from_pretrained(MODEL_NAME)

In [None]:
# Simple seq2seq translation head for transfer learning
class TranslationHead(nn.Module):
    """GRU encoder/decoder that sits on top of a BERT-style encoder.

    The encoder GRU summarises the BERT hidden states into a single hidden
    vector; the decoder GRU then unrolls one step per source position, feeding
    each step's projected output back in as the next step's input.

    Args:
        input_size: Width of the BERT hidden states fed to the encoder GRU.
        hidden_size: Hidden width of both GRUs.
        output_size: Width of the per-step projection returned by ``forward``.
        bert: Optional callable/module returning an object with a
            ``last_hidden_state`` attribute. When omitted, ``forward`` falls
            back to the notebook-level ``bert_model`` (the original behaviour).
    """

    def __init__(self, input_size, hidden_size, output_size, bert=None):
        super().__init__()
        self.bert = bert
        self.encoder = nn.GRU(input_size, hidden_size, batch_first=True)
        # The decoder consumes its own previous projection, so its input width
        # must be output_size (identical to before when output_size == hidden_size).
        self.decoder = nn.GRU(output_size, hidden_size, batch_first=True)
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, input_ids, attention_mask):
        """Return per-position outputs of shape (batch, seq_len, output_size)."""
        encoder_model = self.bert if self.bert is not None else bert_model
        bert_output = encoder_model(input_ids=input_ids, attention_mask=attention_mask)
        encoder_input = bert_output.last_hidden_state

        # encoder_hidden is (num_layers, batch, hidden_size): exactly the h_0
        # layout the decoder GRU expects.
        _, decoder_hidden = self.encoder(encoder_input)

        batch_size, seq_len = encoder_input.size(0), encoder_input.size(1)
        # With batch_first=True the decoder input must be (batch, 1, features);
        # start from an all-zero "previous output".
        decoder_input = encoder_input.new_zeros(
            batch_size, 1, self.output_layer.out_features
        )

        # Free-running decoding: each step is conditioned on the previous
        # prediction (no ground-truth targets are available here, so this is
        # not teacher forcing).
        outputs = []
        for _ in range(seq_len):
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden)
            output_step = self.output_layer(decoder_output)
            outputs.append(output_step)
            decoder_input = output_step

        # Each step is (batch, 1, output_size); concatenate along the time axis
        # rather than stacking, which would add a spurious fourth dimension.
        return torch.cat(outputs, dim=1)

In [None]:
class KR_BERT_TranslationHead(nn.Module):
    """Wraps a translation head and projects its outputs to class logits.

    Args:
        model: The pretrained encoder. It is stored as ``self.bert_model`` (so,
            when it is an ``nn.Module``, its parameters are registered on this
            wrapper) but it is not called directly in ``forward``.
        translation_head: Callable taking ``(input_ids, attention_mask)`` and
            returning a tensor whose last dimension is ``head_output_size``.
        head_output_size: Width of the translation head's output. Defaults to
            256, the previously hard-coded value.
        num_classes: Number of logits produced per position. Defaults to 400,
            the previously hard-coded value; pass the tokenizer vocabulary size
            when the targets are token ids.
    """

    def __init__(self, model, translation_head, head_output_size=256, num_classes=400):
        super().__init__()
        self.bert_model = model
        self.translation_head = translation_head
        self.fc = nn.Linear(head_output_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Run the translation head, then map its last dimension to ``num_classes`` logits."""
        head_output = self.translation_head(input_ids, attention_mask)
        return self.fc(head_output)

In [None]:
csv_path = "en_kr_data/train/ko2en_training_csv/ko2en_finance_1_training.csv" # finance translations
df = pd.read_csv(csv_path, sep=',')

# Rows missing either side of the pair cannot be tokenized (NaN is a float,
# not a string), so drop them without mutating the raw frame.
sentence_pairs = df.dropna(subset=["한국어", "영어"])

# Tokenize and encode the Korean and English sentences
tokenized_inputs = tokenizer(
    sentence_pairs["한국어"].astype(str).tolist(),
    return_tensors="pt", padding=True, truncation=True,
)
labels = tokenizer(
    sentence_pairs["영어"].astype(str).tolist(),
    return_tensors="pt", padding=True, truncation=True,
)["input_ids"]

# The model emits one prediction per source position, so sources and targets
# must share a sequence length. Each side was padded to its own longest
# sentence above; right-pad the shorter side up to the longer one.
source_ids = tokenized_inputs["input_ids"]
source_mask = tokenized_inputs["attention_mask"]
shared_len = max(source_ids.size(1), labels.size(1))
pad_id = tokenizer.pad_token_id

source_ids = nn.functional.pad(source_ids, (0, shared_len - source_ids.size(1)), value=pad_id)
source_mask = nn.functional.pad(source_mask, (0, shared_len - source_mask.size(1)), value=0)
labels = nn.functional.pad(labels, (0, shared_len - labels.size(1)), value=pad_id)
assert source_ids.shape == source_mask.shape == labels.shape

train_dataset = TensorDataset(source_ids, source_mask, labels)
batch_size = 1
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Size the GRU head from the encoder's config rather than hard-coding its width.
HEAD_HIDDEN_SIZE = 256
translation_head = TranslationHead(bert_model.config.hidden_size, HEAD_HIDDEN_SIZE, HEAD_HIDDEN_SIZE)
model = KR_BERT_TranslationHead(bert_model, translation_head)

# The targets are tokenizer ids, so the final projection needs one logit per
# vocabulary entry; a narrower layer makes CrossEntropyLoss fail with
# out-of-range targets. Widen it before the optimizer captures the parameters.
vocab_size = len(tokenizer)
if model.fc.out_features < vocab_size:
    model.fc = nn.Linear(model.fc.in_features, vocab_size)

# Padding positions carry no translation signal, so exclude them from the loss.
pad_id = tokenizer.pad_token_id
loss_function = nn.CrossEntropyLoss(ignore_index=pad_id if pad_id is not None else -100)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0

    # Batch-local names keep the dataset-level `labels` tensor intact.
    for batch_input_ids, batch_attention_mask, batch_labels in train_dataloader:
        logits = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask)

        # The model predicts one step per source position while the targets are
        # padded to their own length; score only the positions both cover.
        steps = min(logits.size(1), batch_labels.size(1))

        # CrossEntropyLoss expects 2D logits (N, classes) and 1D labels (N,).
        logits_flat = logits[:, :steps].reshape(-1, logits.size(-1))
        labels_flat = batch_labels[:, :steps].reshape(-1)

        loss = loss_function(logits_flat, labels_flat)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    model.eval()
    # TODO: validation pass - run the model under torch.no_grad() on a held-out
    # split, decode predictions and targets with the tokenizer, and report BLEU.

    # Report the mean loss over the epoch (also safe if the loader was empty).
    mean_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch {epoch + 1}, Training Loss: {mean_loss}")