<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Fine_Tuning_Pretrained_Models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup

# Fine-tune only the classification head of a pre-trained BERT model on a
# small binary text-classification dataset.

# Load the tokenizer that matches the checkpoint used for the model below.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Hyperparameters live in one place so the scheduler length and the training
# loop cannot drift apart.
NUM_EPOCHS = 3
BATCH_SIZE = 2
LEARNING_RATE = 2e-5

# Example dataset (replace with your actual data)
texts = ["I love programming.", "Transformers are awesome!", "Hello, world!"]
labels = [1, 1, 0]  # Example binary labels

# Tokenize the whole dataset in one call; padding=True pads every sequence to
# the longest one so the results stack into single tensors.
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
labels = torch.tensor(labels)

# Dataset and DataLoader
dataset = torch.utils.data.TensorDataset(input_ids, attention_mask, labels)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize the BERT model; num_labels=2 matches the binary labels above.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Freeze the BERT encoder so that only the classifier head is trained.
for param in model.bert.parameters():
    param.requires_grad = False

# Move the model to the target device before building the optimizer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are set to the old transformers defaults so optimisation behaves
# as before (torch's own defaults are eps=1e-8, weight_decay=0.01).
optimizer = torch.optim.AdamW(
    model.classifier.parameters(), lr=LEARNING_RATE, eps=1e-6, weight_decay=0.0
)

# Scheduler: linear schedule with no warmup, spanning every optimizer step of
# training.
total_steps = len(dataloader) * NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=total_steps
)

# Training loop. Passing `labels` to the model makes it return the loss in
# `outputs.loss`, so no separate loss criterion is needed.
model.train()
for epoch in range(NUM_EPOCHS):
    for batch in dataloader:
        # Distinct names keep the full-dataset tensors above from being
        # overwritten by per-batch tensors.
        batch_input_ids, batch_attention_mask, batch_labels = (t.to(device) for t in batch)

        optimizer.zero_grad()
        outputs = model(
            batch_input_ids, attention_mask=batch_attention_mask, labels=batch_labels
        )
        outputs.loss.backward()
        optimizer.step()
        scheduler.step()

    print(f"Epoch {epoch+1} completed.")

print("Fine-tuning completed!")