In [1]:
import torch
import torch.nn as nn
from transformers import DistilBertModel, DistilBertTokenizer

# ----------------------
# 🔹 LSTM Model
# ----------------------
class LSTMClassifier(nn.Module):
    """Sequence classifier: embedding -> LSTM -> linear layer -> softmax.

    Args:
        vocab_size: Number of entries in the embedding table.
        embed_dim: Size of each embedding vector (also the LSTM input size).
        hidden_dim: Size of the LSTM hidden state.
        output_dim: Number of classes produced by the final linear layer.

    NOTE(review): ``forward`` returns softmax probabilities, not raw logits.
    If the training loop uses ``nn.CrossEntropyLoss`` (which applies
    log-softmax itself), the softmax would effectively be applied twice --
    confirm against the caller before relying on this output for training.
    """

    def __init__(
        self, vocab_size: int, embed_dim: int, hidden_dim: int, output_dim: int
    ) -> None:
        super().__init__()
        # Attribute names and creation order are kept stable so checkpoint
        # keys and seeded parameter initialisation stay the same.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embed_dim
        )
        self.lstm = nn.LSTM(
            input_size=embed_dim, hidden_size=hidden_dim, batch_first=True
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=output_dim)
        # Normalises over dim 1, i.e. the class axis of the (batch, classes)
        # tensor produced by ``fc``.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class probabilities for a batch of token-id sequences.

        Args:
            x: Integer token ids; the indexing below requires the embedded
                input to be batched, i.e. ``x`` of shape (batch, seq_len).

        Returns:
            Tensor of shape (batch, output_dim) whose rows sum to 1.
        """
        embedded = self.embedding(x)
        sequence_out, _state = self.lstm(embedded)
        # Only the final time step feeds the classifier head.
        # NOTE(review): if inputs are right-padded, this position may be a
        # padding token -- verify how batches are built.
        final_step = sequence_out[:, -1, :]
        scores = self.fc(final_step)
        return self.softmax(scores)

# ----------------------
# 🔹 Transformer (BERT) Model
# ----------------------
class BERTClassifier(nn.Module):
    """DistilBERT encoder followed by a linear classification head.

    Args:
        num_labels: Number of output classes of the linear head.
        model_name: Name or path of the pretrained DistilBERT checkpoint
            passed to ``DistilBertModel.from_pretrained``. Defaults to
            ``"distilbert-base-uncased"``, the checkpoint that was
            previously hard-coded.
    """

    def __init__(self, num_labels, model_name="distilbert-base-uncased"):
        super().__init__()
        # Loading the checkpoint may download weights on first use.
        self.bert = DistilBertModel.from_pretrained(model_name)
        # The head's input width follows the encoder's configured hidden size,
        # so other DistilBERT checkpoints work without further changes.
        self.fc = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class scores (logits) for each input sequence.

        Args:
            input_ids: Token ids as produced by the matching tokenizer.
            attention_mask: Mask forwarded unchanged to the encoder.

        Returns:
            Output of the linear head applied to the hidden state at
            sequence position 0; no softmax is applied here.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 of the last hidden state is used as the sequence summary
        # (the [CLS] token representation).
        cls_output = outputs.last_hidden_state[:, 0, :]
        return self.fc(cls_output)

# ----------------------
# 🔹 Tokenization Example (For BERT)
# ----------------------
# Load the tokenizer that belongs to the same pretrained checkpoint as the model.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Single example sentence to encode.
text = "I love deep learning!"

# return_tensors="pt" yields PyTorch tensors; the result holds `input_ids`
# and `attention_mask`, the two inputs BERTClassifier.forward expects.
tokens = tokenizer(
    text,
    padding=True,
    truncation=True,
    return_tensors="pt",
)
print("Tokenized Input:", tokens)


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Tokenized Input: {'input_ids': tensor([[ 101, 1045, 2293, 2784, 4083,  999,  102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}
