# train_last_20_perc_layers 

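In this notebook we fine-tune the pretrained Gemma 3 4B model (`google/gemma-3-4b-pt`) for three-way tweet sentiment classification (negative / neutral / positive). We freeze the embedding layer and the first 80% of the transformer blocks, since these early layers likely capture general-purpose features, and train only the remaining 20% of the transformer blocks together with a new classification head.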

In [None]:
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, Gemma3Model,  TrainingArguments, Trainer
from huggingface_hub import login
from dotenv import load_dotenv
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from tqdm import tqdm
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# NOTE: we use the pretrained checkpoint (the model before SFT / instruction tuning)
# because we fine-tune it on our own labelled dataset.

load_dotenv()

HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
MODEL = "google/gemma-3-4b-pt"
SEED = 69

# Seed torch (CPU and all CUDA devices) so the classification-head initialisation and
# dropout masks are reproducible across runs.
torch.manual_seed(SEED)

login(token=HUGGINGFACE_TOKEN)

In [None]:
# Load the tweet sentiment dataset (MTEB `tweet_sentiment_extraction`): every tweet is
# labelled negative / neutral / positive. Keep pandas copies of both splits for inspection.
raw_dataset = load_dataset("mteb/tweet_sentiment_extraction")
df_train, df_test = (pd.DataFrame(raw_dataset[split]) for split in ("train", "test"))

In [None]:
# Each tweet carries one class id: 0 (negative), 1 (neutral) or 2 (positive).
pd.unique(df_train['label'])

array([1, 0, 2])

In [None]:
# Peek at a few rows rather than rendering the whole training frame.
df_train.head()

Unnamed: 0,id,text,label,label_text
0,cb774db0d1,"I`d have responded, if I were going",1,neutral
1,549e992a42,Sooo SAD I will miss you here in San Diego!!!,0,negative
2,088c60f138,my boss is bullying me...,0,negative
3,9642c003ef,what interview! leave me alone,0,negative
4,358bd9e861,"Sons of ****, why couldn`t they put them on t...",0,negative
...,...,...,...,...
26727,4eac33d1c0,wish we could come see u on Denver husband l...,0,negative
26728,4f4c4fc327,I`ve wondered about rake to. The client has ...,0,negative
26729,f67aae2310,Yay good for both of you. Enjoy the break - y...,2,positive
26730,ed167662a5,But it was worth it ****.,2,positive


In [None]:
# The tokenizer converts raw text into the token ids the model expects.
# - padding_side="left": the classifier pools the hidden state at the LAST position,
#   which must therefore be a real token, not padding.
# - no trust_remote_code: Gemma 3 is implemented natively in transformers (Gemma3Model is
#   imported from it), so there is no reason to execute code downloaded from the Hub.
tokenizer = AutoTokenizer.from_pretrained(MODEL, padding_side="left")

In [None]:
# Sanity-check the tokenizer on two sentences of different length (encode, then decode).
text = ['hello world', 'bobby like to eat pizza']
vec = tokenizer(text, padding=True)
print("encoding: ",vec)

print("decoding: ",tokenizer.batch_decode(vec['input_ids']))

# The shorter sentence must be padded on the LEFT: the classifier pools the last
# position, so that position has to hold a real token rather than padding.
assert vec['attention_mask'][0][0] == 0 and vec['attention_mask'][0][-1] == 1, "tokenizer must pad on the left"

encoding:  {'input_ids': [[0, 0, 0, 0, 2, 23391, 1902], [2, 236763, 13990, 1133, 531, 9039, 19406]], 'attention_mask': [[0, 0, 0, 0, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1]]}
decoding:  ['<pad><pad><pad><pad><bos>hello world', '<bos>bobby like to eat pizza']


In [None]:
# Batch tokenizer for use with `Dataset.map(..., batched=True)`.
def tokenize_dataset(data, max_length=128):
    """Tokenize a batch of examples to a fixed length.

    Every text is padded/truncated to exactly ``max_length`` tokens so the results can
    later be stacked into a single fixed-size tensor.

    Args:
        data: a batch mapping with a ``'text'`` column (list of strings).
        max_length: target sequence length in tokens (default 128).

    Returns:
        The tokenizer's output for the batch (``input_ids`` and ``attention_mask``).
    """
    return tokenizer(data['text'], padding="max_length", truncation=True, max_length=max_length)

In [None]:
# Tokenize every split of the raw DatasetDict in batches with tokenize_dataset.
dataset = raw_dataset.map(function=tokenize_dataset, batched=True)

In [None]:
# Number of rows that go into the 70% training portion of the split below.
n_train_rows = int(0.7 * len(dataset['train']))
n_train_rows

18712

In [None]:
# Shuffle the labelled training split once (seeded) and cut it into a 70% train /
# 30% dev split. The separate test split is left untouched.
TRAIN_FRACTION = 0.7

shuffled = dataset['train'].shuffle(seed=SEED)
n_train = int(len(shuffled) * TRAIN_FRACTION)
train = shuffled.select(range(n_train))
dev = shuffled.select(range(n_train, len(shuffled)))

In [None]:
# Turn both splits into tensors: token ids as model inputs, labels as one-hot float
# vectors over the 3 sentiment classes.
X_train, X_dev = (torch.tensor(split['input_ids']) for split in (train, dev))
y_train, y_dev = (
    F.one_hot(torch.tensor(split['label']), num_classes=3).float() for split in (train, dev)
)

X_train.shape, y_train.shape, X_dev.shape, y_dev.shape

(torch.Size([18712, 128]),
 torch.Size([18712, 3]),
 torch.Size([8020, 128]),
 torch.Size([8020, 3]))

In [None]:
# Micro-batches stay small because of GPU memory; gradient accumulation in the training
# loop makes up the effective batch size.
BATCH_SIZE = 4

train_dataset = TensorDataset(X_train, y_train)
dev_dataset = TensorDataset(X_dev, y_dev)

# Reshuffle training batches every epoch (seeded, so runs stay reproducible).
# Dev order is irrelevant for evaluation.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          generator=torch.Generator().manual_seed(SEED))
dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE)

In [None]:
def check_gpu_memory():
    """Print allocated, reserved ("Cached") and total memory for every visible CUDA device.

    Prints nothing when CUDA is unavailable. Sizes are shown in units of 1024**3 bytes.
    """
    if not torch.cuda.is_available():
        return
    gib = 1024 ** 3
    for idx in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(idx) / gib
        reserved = torch.cuda.memory_reserved(idx) / gib
        capacity = torch.cuda.get_device_properties(idx).total_memory / gib
        print(f"\nGPU {idx}:")
        print(f"  Allocated: {allocated:.2f} GB")
        print(f"  Cached: {reserved:.2f} GB")
        print(f"  Total: {capacity:.2f} GB")


In [None]:
# Gemma3Model is the bare backbone, so we add our own classification head on top of its
# hidden states (see Gemma3Classifier below).
# Weights are spread over two GPUs plus CPU RAM. GPU 1 gets the smaller budget because the
# classification head, pooled embeddings and labels are placed there during training.
MAX_MEMORY = {
    0: "20GiB",
    1: "8GiB",
    "cpu": "80GiB",  # unit spelled consistently with the GPU entries
}

baseModel = Gemma3Model.from_pretrained(
    MODEL,
    device_map='auto',
    output_hidden_states=True,  # the classifier reads hidden_states[-1]
    attn_implementation="eager",
    max_memory=MAX_MEMORY,
)
# NOTE(review): the load log reports that some parameters were offloaded to the CPU.
# Check `baseModel.hf_device_map` to confirm none of the layers trained later ended up there.

check_gpu_memory()

Loading checkpoint shards: 100%|██████████| 2/2 [00:06<00:00,  3.32s/it]
Some parameters are on the meta device because they were offloaded to the cpu.



GPU 0:
  Allocated: 7.93 GB
  Cached: 7.94 GB
  Total: 23.67 GB

GPU 1:
  Allocated: 7.74 GB
  Cached: 7.74 GB
  Total: 23.67 GB


In [None]:
# Print the shape of every parameter tensor in the backbone, then how many tensors there are.
param_shapes = [p.shape for p in baseModel.parameters()]
for shape in param_shapes:
    print(shape)
total = len(param_shapes)
print(total)

torch.Size([1152, 3, 14, 14])
torch.Size([1152])
torch.Size([4096, 1152])
torch.Size([1152])
torch.Size([1152])
torch.Size([1152, 1152])
torch.Size([1152])
torch.Size([1152, 1152])
torch.Size([1152])
torch.Size([1152, 1152])
torch.Size([1152])
torch.Size([1152, 1152])
torch.Size([1152])
torch.Size([1152])
torch.Size([1152])
torch.Size([4304, 1152])
torch.Size([4304])
torch.Size([1152, 4304])
torch.Size([1152])
torch.Size([1152])
torch.Size([1152])
torch.Size([1152, 1152])
torch.Size([1152])
torch.Size([1152, 1152])
torch.Size([1152])
torch.Size([1152, 1152])
torch.Size([1152])
torch.Size([1152, 1152])
torch.Size([1152])
torch.Size([1152])
torch.Size([1152])
torch.Size([4304, 1152])
torch.Size([4304])
torch.Size([1152, 4304])
torch.Size([1152])
torch.Size([1152])
torch.Size([1152])
torch.Size([1152, 1152])
torch.Size([1152])
torch.Size([1152, 1152])
torch.Size([1152])
torch.Size([1152, 1152])
torch.Size([1152])
torch.Size([1152, 1152])
torch.Size([1152])
torch.Size([1152])
torch.Size([1

In [None]:
# Inspect the architecture to decide where to cut between frozen and trainable layers.
# print(model) shows the module tree once; iterating named_modules() would re-print
# every sub-tree for each of its ancestors.
print(baseModel)

# Number of text decoder layers; the freezing cell below takes 80% of these.
len(baseModel.language_model.layers)

('', Gemma3Model(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(4096, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-26): 27 x SiglipEncoderLayer(
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (self_attn): SiglipAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm2): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
              (activation_

In [None]:
# Freeze the token embeddings and the first 80% of the text decoder layers. The remaining
# decoder layers (and any language-model parameters not touched here) stay trainable.
# Freezing does not shrink the weights already on the GPU, but frozen parameters need no
# gradients or optimizer state, which saves memory during training.
FREEZE_FRACTION = 0.8

for param in baseModel.language_model.embed_tokens.parameters():
    param.requires_grad = False

decoder_layers = baseModel.language_model.layers
# Index of the last frozen layer. For a 34-layer stack (consistent with the trainable
# parameter count reported later) this is 26, i.e. layers 0-26 frozen and 27-33 trained.
max_layer_to_freeze = int(FREEZE_FRACTION * len(decoder_layers)) - 1
for layer in decoder_layers[: max_layer_to_freeze + 1]:
    for param in layer.parameters():
        param.requires_grad = False


In [None]:
# This is a text-only task: the classifier only passes token ids to the backbone. Move the
# vision tower to the CPU to free GPU memory, then freeze it together with the multimodal
# projector.
baseModel.vision_tower = baseModel.vision_tower.to("cpu")
for frozen_module in (baseModel.vision_tower, baseModel.multi_modal_projector):
    for param in frozen_module.parameters():
        param.requires_grad = False

check_gpu_memory()


GPU 0:
  Allocated: 6.38 GB
  Cached: 7.94 GB
  Total: 23.67 GB

GPU 1:
  Allocated: 7.74 GB
  Cached: 7.74 GB
  Total: 23.67 GB


In [None]:
# Training-time settings for the backbone.
baseModel.config.output_hidden_states = True  # the classifier reads hidden_states[-1]

# KV caching is useless for training and conflicts with gradient checkpointing. The text
# decoder reads its own text_config, and setting only the top-level flag did not take effect
# (the training log still warns "`use_cache=True` is incompatible with gradient checkpointing"),
# so set it there too.
baseModel.config.use_cache = False
baseModel.config.text_config.use_cache = False

# Use NON-reentrant checkpointing. With the reentrant variant (the transformers default),
# a checkpointed segment only produces gradients if one of its inputs requires grad.
# The embeddings and the first layers are frozen, so the hidden states reaching the
# trainable layers do not require grad, and those layers would silently get no gradient.
baseModel.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

In [None]:
class Gemma3Classifier(nn.Module):
    """Sentiment classifier: a (partially frozen) Gemma3 backbone plus a linear head.

    The backbone's final hidden state at the last real (non-padding) token is used as the
    sequence embedding. It is passed through dropout and projected to class logits.

    Args:
        bmodel: backbone, called as ``bmodel(input_ids=..., attention_mask=...)``. Its output
            must expose a ``hidden_states`` sequence (the notebook loads it with
            ``output_hidden_states=True``).
        hiddensize: width of the backbone hidden states (input size of the head).
        dropout: dropout probability applied to the pooled embedding.
        head_device: device for the head and the pooled embedding. Defaults to ``'cuda:1'``,
            matching this notebook's two-GPU placement.
        pad_token_id: id of the padding token, used to build an attention mask when
            ``forward`` is not given one. If ``None``, it is looked up on the backbone's
            ``config.text_config`` (or ``config``) when available.
        num_labels: number of output classes.
    """

    def __init__(self, bmodel, hiddensize, dropout=0.1, head_device='cuda:1',
                 pad_token_id=None, num_labels=3):
        super().__init__()
        self.bmodel = bmodel
        self.dropout = nn.Dropout(dropout)
        self.head_device = head_device
        self.head = nn.Linear(hiddensize, num_labels).to(head_device)
        self.device_placement = True
        if pad_token_id is None:
            # NOTE(review): assumes the backbone config's pad_token_id matches the tokenizer's
            # padding id. Pass pad_token_id explicitly to be certain.
            config = getattr(bmodel, "config", None)
            text_config = getattr(config, "text_config", config)
            pad_token_id = getattr(text_config, "pad_token_id", None)
        self.pad_token_id = pad_token_id

    def forward(self, input_ids, attention_mask=None):
        """Return class logits of shape (batch, num_labels) on ``head_device``.

        Args:
            input_ids: token ids of shape (batch, seq_len).
            attention_mask: optional (batch, seq_len) mask, 1 for real tokens and 0 for
                padding. If omitted and ``pad_token_id`` is known, it is derived from
                ``input_ids``.
        """
        if attention_mask is None and self.pad_token_id is not None:
            # Without a mask the backbone would attend to the padding tokens.
            attention_mask = (input_ids != self.pad_token_id).long()

        # attention_mask must go by keyword: the backbone's second positional parameter may
        # be something else.
        out = self.bmodel(input_ids=input_ids, attention_mask=attention_mask)
        hidden_state = out.hidden_states[-1]  # (batch, seq_len, hidden)

        if attention_mask is None:
            embeddings = hidden_state[:, -1, :]
        else:
            # Position of the last real token in each row; correct for left or right padding.
            positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
            last_idx = (attention_mask.long() * positions).argmax(dim=1)
            rows = torch.arange(hidden_state.shape[0], device=hidden_state.device)
            embeddings = hidden_state[rows, last_idx.to(hidden_state.device)]

        embeddings = embeddings.to(self.head_device)

        logits = self.head(self.dropout(embeddings))

        return logits

In [None]:
# Wrap the backbone with the classification head; the head's input width is the text
# model's hidden size.
model = Gemma3Classifier(
    baseModel,
    baseModel.config.text_config.hidden_size,
    dropout=0.1,
)

In [None]:
# Per-epoch loss histories (train / dev), loss function and optimizer for the training loop.
# CrossEntropyLoss accepts the one-hot float targets as class probabilities.
lossi, devlossi = [], []
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
torch.cuda.empty_cache()  # return unused cached GPU memory before training starts

In [None]:
# Report how much of the model the optimizer will actually update.
trainable_params, total_params = 0, 0
for p in model.parameters():
    total_params += p.numel()
    if p.requires_grad:
        trainable_params += p.numel()

print(f"Trainable parameters: {trainable_params:,}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable percentage: {100 * trainable_params / total_params:.2f}%")

Trainable parameters: 660,688,387
Total parameters: 4,300,087,155
Trainable percentage: 15.36%


In [None]:

# Training loop with gradient accumulation. The GPU only fits micro-batches of 4, so
# gradients from `accumulation_steps` micro-batches are summed before each optimizer step
# (effective batch size 8 * 4 = 32).
accumulation_steps = 8
num_epochs = 10
n_batches = len(train_loader)

# Clear gradients left over from an earlier (e.g. interrupted) run of this cell.
optimizer.zero_grad()

for epoch in tqdm(range(num_epochs)):
    model.train()

    loss_total = 0.0
    # Batch variables are named xb/yb so they do not overwrite the global X_train/y_train tensors.
    for step, (xb, yb) in enumerate(train_loader):
        logits = model(input_ids=xb)
        loss = criterion(logits, yb.to(logits.device))
        loss_total += loss.item()

        # Average over the accumulation group this batch belongs to. The last group of
        # the epoch may hold fewer than `accumulation_steps` batches.
        group_start = step - step % accumulation_steps
        group_size = min(accumulation_steps, n_batches - group_start)
        (loss / group_size).backward()

        # Update at the end of each full group and after the epoch's final batch.
        if (step + 1) % accumulation_steps == 0 or step + 1 == n_batches:
            optimizer.step()
            optimizer.zero_grad()

    lossi.append(loss_total / n_batches)

    model.eval()
    dev_loss_total = 0.0
    with torch.no_grad():
        for xb, yb in dev_loader:
            logits = model(input_ids=xb)
            dev_loss_total += criterion(logits, yb.to(logits.device)).item()

    devlossi.append(dev_loss_total / len(dev_loader))

  0%|          | 0/10 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


In [None]:
# Mean training and dev loss per epoch; a widening gap indicates over-fitting.
fig, ax = plt.subplots()
ax.plot(range(1, len(lossi) + 1), lossi, label="train loss")
ax.plot(range(1, len(devlossi) + 1), devlossi, label="dev loss")
ax.set(title="Cross-entropy loss per epoch", xlabel="epoch", ylabel="mean loss")
ax.legend();

In [None]:
# Qualitative check on one held-out example from the test split. The raw train split
# was used for training, so an example from it would not show generalisation.
LABEL_NAMES = ['negative', 'neutral', 'positive']  # class ids 0, 1, 2
EXAMPLE_IDX = 4

ex = dataset['test'][EXAMPLE_IDX]
ex_text = ex['text']
ex_input = torch.tensor(ex['input_ids']).unsqueeze(dim=0)
ex_label = ex['label']

model.eval()  # disable dropout even if this cell runs without the training loop
with torch.no_grad():
    pred = model(ex_input)
probs = torch.softmax(pred, dim=1)

print(f'The text is: {ex_text}')
print(f'The label is: {F.one_hot(torch.tensor(ex_label), num_classes=3).tolist()}')
print(f'The pred is: {probs}')
print(f'Predicted class: {LABEL_NAMES[probs.argmax(dim=1).item()]} (true: {LABEL_NAMES[ex_label]})')