In [40]:
import pandas as pd 
import string
# Location of the raw MBTI CSV, relative to the notebook's working directory.
MBTI_CSV_PATH = '../Downloads/mbti_1.csv'
mbti = pd.read_csv(MBTI_CSV_PATH)

import torch
from torch.utils.data import Dataset

class MBTIDataset(Dataset):
    """Map-style dataset over ``(sentence, label)`` pairs.

    Args:
        data: iterable of ``(sentence, label)`` pairs. It is materialised once,
            so a generator is accepted. An empty iterable yields an empty dataset.

    Attributes:
        texts: tuple of sentences, in input order.
        labels: tuple of raw labels, aligned with ``texts``.
        labelmap: dict mapping each distinct label to an integer id. Ids are
            assigned in sorted label order so the mapping is identical across
            runs (iterating a ``set`` of strings is not stable between
            interpreter sessions because of hash randomisation).
    """

    def __init__(self, data):
        pairs = list(data)
        if pairs:
            self.texts, self.labels = zip(*pairs)
        else:
            # zip(*[]) yields nothing to unpack; represent "no samples" explicitly.
            self.texts, self.labels = (), ()
        self.labelmap = {
            label: idx for idx, label in enumerate(sorted(set(self.labels)))
        }

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(label_id, sentence)`` for sample ``idx``.

        Note the order: the integer label comes first, then the raw sentence
        string (tokenisation is left to the consumer).
        """
        sentence, label = self.texts[idx], self.labels[idx]
        return self.labelmap[label], sentence

# Built once: deletes every ASCII punctuation character.
_PUNCTUATION_TABLE = str.maketrans('', '', string.punctuation)


def clean_sentence(sen: str) -> str:
    """Remove URL-like tokens and ASCII punctuation from ``sen``.

    Steps:
      1. Drop every whitespace-delimited token that contains ``'http'``.
      2. Delete all characters in ``string.punctuation``.
      3. Collapse runs of whitespace to single spaces and trim both ends.

    Step 3 runs *after* punctuation removal so that tokens made only of
    punctuation (e.g. ``"hi , there !"``) do not leave double or trailing
    spaces behind.

    Args:
        sen: raw sentence.

    Returns:
        The cleaned sentence; ``''`` if nothing survives.
    """
    without_urls = ' '.join(word for word in sen.split() if 'http' not in word)
    without_punct = without_urls.translate(_PUNCTUATION_TABLE)
    return ' '.join(without_punct.split())

# Each row of `mbti` is unpacked as (type, posts), so the frame must have
# exactly two columns. `posts` holds several posts joined by '|||'; every post
# is cleaned and kept only if more than 10 characters remain.
samples = [
    (sen, typ)
    for typ, sens in mbti.itertuples(index=False, name=None)
    for sen in map(clean_sentence, sens.split('|||'))
    if len(sen) > 10
]
data = MBTIDataset(samples)

In [47]:

from torch import nn 
import torch 
from typing import List 
from transformers import AutoModel, AutoTokenizer, AutoConfig

class MBTIModel(nn.Module):
    """Pretrained transformer encoder with a linear classification head.

    The module owns its tokenizer, so ``forward`` takes raw strings and
    returns one row of class scores per input text.

    Args:
        pretrained: model name or path passed to ``AutoModel`` / ``AutoTokenizer``.
        num_classes: number of output classes.
        device: device string the module and tokenised batches are moved to.
    """

    def __init__(self, pretrained: str, num_classes: int, device: str):
        super().__init__()
        # return_dict=False makes the encoder return a plain tuple.
        self.model = AutoModel.from_pretrained(pretrained, return_dict=False)
        self.tkz = AutoTokenizer.from_pretrained(pretrained)
        # Read the hidden size from the already-loaded encoder instead of
        # loading the config a second time.
        self.dim = self.model.config.hidden_size
        # NOTE(review): the LeakyReLU is applied to the scores that go straight
        # into CrossEntropyLoss; kept as-is, but confirm it is intentional.
        self.ffn = nn.Sequential(
            nn.Linear(self.dim, num_classes),
            nn.LeakyReLU(0.1)
        )
        self.criteria = nn.CrossEntropyLoss()
        self.device = device
        self.to(device)

    def forward(self, texts: List[str]) -> torch.Tensor:
        """Tokenise ``texts`` and return class scores of shape ``(len(texts), num_classes)``.

        Any sequence of strings is accepted (a DataLoader batch of strings
        arrives as a tuple), so it is converted to a list first.
        """
        enc = self.tkz(
            list(texts),
            padding=True,
            truncation=True,
            padding_side='right',  # keep position 0 a real token, never padding
            return_tensors='pt',
        ).to(self.device)
        # Element 0 of the output tuple is the per-token hidden states
        # (batch, seq_len, hidden); take the first token's vector as the
        # sentence representation. The previous `[1][:, 0]` sliced the pooled
        # (batch, hidden) output down to (batch,), which cannot feed the
        # Linear(dim, num_classes) head.
        hidden_states = self.model(
            input_ids=enc.input_ids, attention_mask=enc.attention_mask
        )[0]
        return self.ffn(hidden_states[:, 0])

    def loss(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Cross-entropy between ``logits`` and integer class ``labels``.

        ``labels`` is moved to the device of ``logits`` so CPU-side labels
        coming from a DataLoader work with a model on GPU.
        """
        labels = torch.as_tensor(labels, device=logits.device)
        return self.criteria(logits, labels)

    @torch.no_grad()
    def eval(self, texts: List[str] = None, labels: torch.Tensor = None):
        """Switch to evaluation mode, or compute batch accuracy.

        * ``model.eval()`` with no arguments behaves like the standard
          ``nn.Module.eval()`` and returns ``self``. This method shadows the
          base one, so without this path the usual call would raise.
        * ``model.eval(texts, labels)`` returns the fraction of ``texts`` whose
          highest-scoring class equals the corresponding label, as a 0-dim
          tensor. Gradients are disabled. The train/eval mode is not changed
          by this call; call ``model.eval()`` first to disable dropout.

        Raises:
            ValueError: if only one of ``texts`` / ``labels`` is given.
        """
        if texts is None and labels is None:
            return super().eval()
        if texts is None or labels is None:
            raise ValueError("eval() needs both texts and labels, or neither")
        logits = self(list(texts))
        labels = torch.as_tensor(labels, device=logits.device)
        # Compare predicted class ids (not raw scores) against the labels.
        predictions = logits.argmax(dim=-1)
        return (predictions == labels).sum() / len(texts)

# Use the first GPU when available; otherwise fall back to CPU so this cell
# also runs on machines without CUDA.
model = MBTIModel(
    pretrained='bert-base-uncased',
    num_classes=16,  # presumably one class per MBTI type — confirm against data.labelmap
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
)

In [None]:
from torch.utils.data import DataLoader
from torch.optim import AdamW
train_ds = DataLoader(data, batch_size=100, shuffle=True)
# AdamW's default lr (1e-3) is too large for fine-tuning a pretrained
# transformer; 2e-5 is a conventional starting point — tune as needed.
opt = AdamW(model.parameters(), lr=2e-5)

model.train()
for epoch in range(10):
    # MBTIDataset.__getitem__ returns (label_id, sentence), so each collated
    # batch unpacks as (labels, texts): labels is already an integer tensor,
    # texts is a sequence of strings.
    for labels, texts in train_ds:
        logits = model(list(texts))
        # Labels are collated on CPU; move them next to the logits.
        labels = torch.as_tensor(labels, device=logits.device)
        loss = model.loss(logits, labels)

        opt.zero_grad()
        loss.backward()
        opt.step()
        
        
        
