<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/_Code_Example_Building_a_Modular_Foundation_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from transformers import BertForMaskedLM, BertTokenizer, BertModel, AdamW
from torch.utils.data import Dataset, DataLoader

# Device configuration: use CUDA when torch reports it as available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Define a custom dataset
class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text record per item.

    Every record in ``data`` is indexed with the ``"text"`` key and, when
    ``for_classification`` is true, also with the ``"label"`` key.

    Args:
        data: Indexable collection of records that supports ``len()``.
        tokenizer: Callable invoked as ``tokenizer(text, padding='max_length',
            truncation=True, max_length=..., return_tensors='pt')``; its
            result is indexed with ``"input_ids"`` and ``"attention_mask"``.
        max_length: Value forwarded to the tokenizer as ``max_length``.
        for_classification: When true, each item also carries the record's
            label as a third element.
    """

    def __init__(self, data, tokenizer, max_length=128, for_classification=False):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.for_classification = for_classification

    def __len__(self):
        """Return the number of records in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize record ``idx``.

        Returns:
            ``(input_ids, attention_mask)``, or
            ``(input_ids, attention_mask, label)`` when the dataset was
            built with ``for_classification=True``.
        """
        record = self.data[idx]
        encoding = self.tokenizer(
            record["text"],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt',
        )
        # Remove only the leading axis (assumed to be the size-1 batch axis
        # the tokenizer adds for a single text). A bare squeeze() would also
        # collapse the sequence axis whenever it has length 1, producing a
        # 0-dim tensor that no longer batches as a sequence.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        if self.for_classification:
            return input_ids, attention_mask, record["label"]
        return input_ids, attention_mask

# Define the FoundationModel class for MLM
class FoundationModel(nn.Module):
    """Masked-language-model wrapper that pairs a BERT MLM model with its tokenizer.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            model and the tokenizer.
    """

    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.model = BertForMaskedLM.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped model and return its ``(loss, logits)`` pair.

        ``labels`` is forwarded unchanged to the wrapped model; the loss is
        whatever the model output exposes as ``.loss``.
        """
        result = self.model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = result.loss
        logits = result.logits
        return loss, logits

    def encode_text(self, texts, max_length=128):
        """Tokenize ``texts`` and return ``(input_ids, attention_mask)``.

        The tokenizer is called with padding and truncation enabled,
        ``max_length`` forwarded, and PyTorch tensors requested.
        """
        batch = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        ids, mask = batch["input_ids"], batch["attention_mask"]
        return ids, mask

# Mask tokens for MLM
def mask_tokens(inputs, tokenizer, mlm_probability=0.15):
    """Build masked-language-modelling inputs and labels from token ids.

    Each position that is neither a special token nor padding is selected
    independently with probability ``mlm_probability``. Selected positions
    are replaced by the tokenizer's mask token id in the returned inputs and
    keep their original id in the labels; all other label positions are set
    to -100.

    Args:
        inputs: Integer tensor of token ids. Each row is passed to
            ``tokenizer.get_special_tokens_mask``, so a 2-D
            ``(batch, seq_len)`` layout is expected. The tensor itself is
            not modified.
        tokenizer: Object exposing ``mask_token_id``,
            ``get_special_tokens_mask(ids, already_has_special_tokens=True)``
            and, optionally, ``pad_token_id``.
        mlm_probability: Per-token selection probability.

    Returns:
        ``(masked_inputs, labels)``: two new tensors shaped like ``inputs``.
    """
    labels = inputs.clone()
    # Write mask ids into a copy so the caller's tensor is left untouched.
    masked_inputs = inputs.clone()

    # Allocate on the inputs' device (and as float, even if an int
    # probability is passed) so the in-place fills and bernoulli below work
    # for any input placement.
    probability_matrix = torch.full(
        labels.shape, mlm_probability, dtype=torch.float, device=inputs.device
    )
    special_tokens_mask = torch.tensor(
        [
            tokenizer.get_special_tokens_mask(row, already_has_special_tokens=True)
            for row in labels.tolist()
        ],
        dtype=torch.bool,
        device=inputs.device,
    )
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)

    # Never select padding, even if the tokenizer's special-token mask does
    # not flag it.
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is not None:
        probability_matrix.masked_fill_(labels.eq(pad_token_id), value=0.0)

    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # We only compute loss on masked tokens

    masked_inputs[masked_indices] = tokenizer.mask_token_id
    return masked_inputs, labels

# Pretrain the model for MLM
def pretrain_model(model, tokenizer, train_data, epochs=3, batch_size=32, learning_rate=5e-6):  # Reduced learning rate
    """Train ``model`` on ``train_data`` with a masked-language-modelling objective.

    Each batch is masked with ``mask_tokens``, moved to the module-level
    ``device`` and used for one AdamW step. Batches whose loss is NaN are
    reported and skipped. After every epoch the mean loss over the batches
    that were actually used is printed; if none were used, ``nan`` is
    printed instead of a misleading 0.0.

    Args:
        model: Callable as ``model(input_ids, attention_mask, labels=...)``
            returning ``(loss, logits)``; must expose ``parameters()`` and
            ``train()``.
        tokenizer: Tokenizer used both by the dataset and by ``mask_tokens``.
        train_data: Records accepted by ``TextDataset`` (each with a
            ``"text"`` entry).
        epochs: Number of passes over the data.
        batch_size: Batch size given to the ``DataLoader``.
        learning_rate: Learning rate given to AdamW.
    """
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    train_dataset = TextDataset(train_data, tokenizer)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        used_batches = 0
        for input_ids, attention_mask in train_dataloader:
            optimizer.zero_grad()
            input_ids, labels = mask_tokens(input_ids, tokenizer)
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
            loss, _ = model(input_ids, attention_mask, labels=labels)
            if torch.isnan(loss):  # Detect and handle nan
                # NOTE(review): presumably this happens when no token in the
                # batch was selected for masking, leaving no position to
                # score; confirm against the loss implementation.
                print("Nan encountered! Skipping batch.")
                continue
            total_loss += loss.item()
            used_batches += 1
            loss.backward()
            optimizer.step()

        # Average only over batches that contributed a loss; skipped batches
        # must not dilute the mean, and zero batches must not divide by zero.
        average_loss = total_loss / used_batches if used_batches else float("nan")
        print(f"Epoch {epoch + 1}, Loss: {average_loss}")

# Define the ClassificationFoundationModel class for classification
class ClassificationFoundationModel(nn.Module):
    """BERT encoder with a linear classification layer on the first token's state.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            encoder and the tokenizer.
        num_labels: Output size of the linear classifier.
    """

    def __init__(self, model_name="bert-base-uncased", num_labels=2):
        super().__init__()
        self.model = BertModel.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        hidden_size = self.model.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(self, input_ids, attention_mask, labels=None):
        """Classify a batch and return ``(loss, logits)``.

        The classifier is applied to position 0 of the encoder's
        ``last_hidden_state``. ``loss`` is ``None`` unless ``labels`` is
        given, in which case it is the cross-entropy between the logits and
        the flattened labels.
        """
        encoded = self.model(input_ids, attention_mask=attention_mask)
        cls_state = encoded.last_hidden_state[:, 0, :]  # CLS token
        logits = self.classifier(cls_state)
        if labels is None:
            return None, logits
        criterion = nn.CrossEntropyLoss()
        flat_logits = logits.view(-1, self.num_labels)
        flat_labels = labels.view(-1)
        return criterion(flat_logits, flat_labels), logits

    def encode_text(self, texts, max_length=128):
        """Tokenize ``texts`` and return ``(input_ids, attention_mask)``.

        The tokenizer is called with padding and truncation enabled,
        ``max_length`` forwarded, and PyTorch tensors requested.
        """
        batch = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        ids, mask = batch["input_ids"], batch["attention_mask"]
        return ids, mask

# Train the classification model
def train_classifier(model, train_data, epochs=3, batch_size=32, learning_rate=5e-5):
    """Fine-tune a classification model on labelled text records.

    Batches are built with ``TextDataset(..., for_classification=True)``
    using ``model.tokenizer``, moved to the module-level ``device`` and used
    for one AdamW step each. After every epoch the mean batch loss is
    printed; ``nan`` is printed when the loader produced no batches.

    Args:
        model: Callable as ``model(input_ids, attention_mask, labels=...)``
            returning ``(loss, logits)``; must expose ``tokenizer``,
            ``parameters()`` and ``train()``.
        train_data: Records with ``"text"`` and ``"label"`` entries.
        epochs: Number of passes over the data.
        batch_size: Batch size given to the ``DataLoader``.
        learning_rate: Learning rate given to AdamW.
    """
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    train_dataset = TextDataset(train_data, model.tokenizer, for_classification=True)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    num_batches = len(train_dataloader)

    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for input_ids, attention_mask, labels in train_dataloader:
            optimizer.zero_grad()
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
            loss, _ = model(input_ids, attention_mask, labels=labels)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()

        # Guard the mean so an empty loader reports nan instead of raising
        # ZeroDivisionError.
        average_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Epoch {epoch + 1}, Loss: {average_loss}")

# Evaluate the classification model
def evaluate_model(model, test_data, batch_size=32):
    """Measure classification accuracy of ``model`` on ``test_data``.

    The model is switched to eval mode and run without gradient tracking.
    The predicted class is the argmax of the logits along the last axis.
    The accuracy is printed with two decimals and also returned.

    Args:
        model: Callable as ``model(input_ids, attention_mask)`` returning
            ``(loss, logits)``; must expose ``tokenizer`` and ``eval()``.
        test_data: Records with ``"text"`` and ``"label"`` entries.
        batch_size: Batch size given to the ``DataLoader``.

    Returns:
        Fraction of correctly predicted labels as a float, or ``nan`` when
        ``test_data`` yields no examples.
    """
    test_dataset = TextDataset(test_data, model.tokenizer, for_classification=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for input_ids, attention_mask, labels in test_dataloader:
            input_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)
            _, logits = model(input_ids, attention_mask)
            predictions = torch.argmax(logits, dim=-1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    # An empty test set has no defined accuracy; report nan rather than
    # raising ZeroDivisionError.
    accuracy = correct / total if total else float("nan")
    print(f"Accuracy: {accuracy:.2f}")
    return accuracy

# Example usage
# Placeholder records; replace with actual data.
train_data = [dict(text="example training sentence", label=0)]
test_data = [dict(text="example test sentence", label=0)]

# Step 1: build the masked-LM foundation model, move it to the selected
# device, and pretrain it on the training texts.
model = FoundationModel()
model = model.to(device)
pretrain_model(model, model.tokenizer, train_data)

# Step 2: build the classification model, move it to the selected device,
# and train it on the labelled records.
classification_model = ClassificationFoundationModel()
classification_model = classification_model.to(device)
train_classifier(classification_model, train_data)

# Step 3: report accuracy on the held-out records.
evaluate_model(classification_model, test_data)