Loosely following the DataCamp tutorial on fine-tuning large language models:
 
https://www.datacamp.com/tutorial/fine-tuning-large-language-models

In [15]:
import os
import re

import pandas as pd
import torch
import torch.nn as nn
from datasets import load_dataset
from dotenv import load_dotenv
from huggingface_hub import login
from transformers import AutoTokenizer, Gemma3ForCausalLM

In [16]:
# Pull secrets (e.g. HUGGINGFACE_TOKEN) from a local .env file into the environment.
load_dotenv()

HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
MODEL = "google/gemma-3-4b-it"
SEED = 69

# Pick the accelerator that is actually present instead of assuming Apple
# silicon, so the `.to(device)` calls further down also work on CUDA or
# CPU-only machines. MPS is checked first to keep the original behaviour on Macs.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Seed torch as well, so randomly initialised layers are reproducible across runs.
torch.manual_seed(SEED)

login(token=HUGGINGFACE_TOKEN)

In [17]:
# Fetch the tweet sentiment dataset (all splits) from the Hugging Face Hub.
raw_dataset = load_dataset("mteb/tweet_sentiment_extraction")
# Convert the train split with to_pandas() rather than pd.DataFrame(...),
# which would build the frame by iterating over the rows one by one.
df = raw_dataset["train"].to_pandas()

In [18]:
# Spot-check one arbitrary training row (by position) to see the available columns.
df.iloc[26730]

id                             ed167662a5
text           But it was worth it  ****.
label                                   2
label_text                       positive
Name: 26730, dtype: object

In [19]:
# Gemma 3 is supported natively by transformers (Gemma3ForCausalLM is imported
# from it above), so trust_remote_code is not needed; leaving it off avoids
# executing arbitrary code shipped in a model repository.
tokenizer = AutoTokenizer.from_pretrained(MODEL)

In [20]:
# Two sentences of different length, so the shorter one has to be padded;
# the printed output shows where the tokenizer puts the <pad> tokens.
text = ["hello world", "bobby like to eat pizza"]
vec = tokenizer(text, padding=True)
print("encoding: ", vec)

# Round-trip the ids back to text to make the special tokens visible.
print("decoding: ", tokenizer.batch_decode(vec["input_ids"]))

encoding:  {'input_ids': [[0, 0, 0, 0, 2, 23391, 1902], [2, 236763, 13990, 1133, 531, 9039, 19406]], 'attention_mask': [[0, 0, 0, 0, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]]}
decoding:  ['<pad><pad><pad><pad><bos>hello world', '<bos>bobby like to eat pizza']


In [21]:
def tokenize_dataset(data, max_length=128):
    """Tokenize the ``text`` field of a batch with the notebook-level tokenizer.

    Args:
        data: Mapping with a ``"text"`` entry (a batch of rows when used with
            ``Dataset.map(..., batched=True)``).
        max_length: Length every sequence is padded or truncated to.
            Defaults to 128, the value previously hard-coded here.

    Returns:
        Whatever the tokenizer returns for the batch.
    """
    return tokenizer(
        data["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

In [22]:
# Tokenize every split; batched=True hands tokenize_dataset a batch of rows per call.
dataset = raw_dataset.map(tokenize_dataset, batched=True)

Map: 100%|██████████| 3432/3432 [00:00<00:00, 36735.80 examples/s]


In [23]:
# Tiny, seeded subsets (two examples per split) - just enough to smoke-test
# the pipeline end to end.
train = dataset["train"].shuffle(seed=SEED).select(range(2))
test = dataset["test"].shuffle(seed=SEED).select(range(2))

In [24]:
# Gemma 3 is loaded here as a causal LM, so a sequence-classification model
# has to be defined on top of it (see the Gemma3Classifier cell).

baseModel = Gemma3ForCausalLM.from_pretrained(MODEL, device_map="auto")

Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.72s/it]


In [25]:
# Ask the backbone to return its hidden states along with the logits.
baseModel.config.output_hidden_states = True
# Turn on gradient checkpointing on the backbone (trades extra compute for lower memory when training).
baseModel.gradient_checkpointing_enable()

In [26]:
class Gemma3Classifier(nn.Module):
    """Backbone language model with a linear classification head on top.

    By default ``forward`` returns the backbone's own output untouched, exactly
    as before. Pass ``return_class_logits=True`` to get one score per label for
    each sequence instead.

    Args:
        bmodel: Backbone module, called as ``bmodel(input, attention_mask)``.
        hiddensize: Width of the backbone's hidden states (input size of the head).
        dropout: Dropout probability applied to the pooled vector before the head.
        num_labels: Number of output classes. Defaults to 3, the value
            previously hard-coded.
    """

    def __init__(self, bmodel, hiddensize, dropout=0.1, num_labels=3):
        super().__init__()
        self.bmodel = bmodel
        # Previously this stored the bare float, which did nothing; a real
        # Dropout layer follows train()/eval() and is applied in forward.
        self.dropout = nn.Dropout(dropout)
        self.head = nn.Linear(hiddensize, num_labels)

    def forward(self, input, attention_mask, return_class_logits=False):
        """Run the backbone, optionally followed by the classification head.

        Args:
            input: Token ids, passed to the backbone as its first argument.
            attention_mask: Mask with 1 for real tokens and 0 for padding.
            return_class_logits: When False (default) return the backbone
                output unchanged. When True return a ``(batch, num_labels)``
                tensor of class scores.

        Returns:
            The backbone output, or the class-score tensor.
        """
        if not return_class_logits:
            return self.bmodel(input, attention_mask)

        out = self.bmodel(input, attention_mask=attention_mask, output_hidden_states=True)
        last_hidden = out.hidden_states[-1]  # final layer: (batch, seq, hidden)

        # Index of the last real token in every row. Multiplying the mask by
        # the position index and taking the argmax works for both left- and
        # right-padded batches.
        positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
        last_real = (attention_mask * positions).argmax(dim=1).to(last_hidden.device)
        rows = torch.arange(last_hidden.shape[0], device=last_hidden.device)
        pooled = last_hidden[rows, last_real]

        # The head may live on a different device / dtype than the backbone
        # (e.g. backbone placed by device_map, head left on the default device).
        pooled = pooled.to(device=self.head.weight.device, dtype=self.head.weight.dtype)
        return self.head(self.dropout(pooled))
        

In [27]:
# Wrap the loaded backbone; the head's input width is read from the backbone config.
model = Gemma3Classifier(
    bmodel=baseModel,
    hiddensize=baseModel.config.hidden_size,
    dropout=0.1,
)

In [28]:
# Smoke-test forward pass on the two-example training subset. This is for
# inspection only, so run it without autograd to avoid keeping the graph of a
# 4B-parameter model in memory.
with torch.no_grad():
    out = model(
        input=torch.tensor(train['input_ids']).to(device),
        attention_mask=torch.tensor(train['attention_mask']).to(device),
    )

In [31]:
# The shape printed below is (2, 128, 262208): one score per vocabulary entry
# at each of the 128 positions, i.e. next-token logits rather than 3-way
# sentiment scores.
out['logits'].shape

torch.Size([2, 128, 262208])

In [33]:
# Greedy pick: the highest-scoring vocabulary id at every sequence position.
pred_id = out['logits'].argmax(dim=-1)
pred_id

tensor([[114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560,
         114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560,
         114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560,
         114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560,
         114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560,
         114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560,
         114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560,
         114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560,
         114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560,
         114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560,
         114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560,
         114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560, 114560,
         114560, 114560, 114

In [51]:
# Show the raw tweets, then decode the predicted ids of the second example
# (index 1) back into text for comparison.
print(train["text"])
res = tokenizer.batch_decode(pred_id.cpu())[1]
res

Column([' awww, that`s cute.', 'On train with at least two gaggles of teenagers sitting & the commuters squished standing in the back...at least the teenagers let me sit'])


' UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIX UNIXTheThe travel a least  passengershoules of geese. across staring train areishing in. front middle.\n\n least the train are me have down'

In [52]:
# Drop the run of " UNIX" tokens predicted at the padded positions.
# str.strip(" UNIX ") would treat its argument as a *set of characters* and
# also eat any leading/trailing U, N, I, X or space belonging to the real
# text, so remove the repeated prefix explicitly and then trim spaces only.
re.sub(r"^(?: UNIX)+", "", res).strip(" ")

'TheThe travel a least  passengershoules of geese. across staring train areishing in. front middle.\n\n least the train are me have down'