In [None]:
# Mount Google Drive for storing checkpoints
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!pip install lightning

Collecting lightning
  Downloading lightning-2.5.1-py3-none-any.whl.metadata (39 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.14.2-py3-none-any.whl.metadata (5.6 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.7.0-py3-none-any.whl.metadata (21 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.5.1-py3-none-any.whl.metadata (20 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<4.0,>=2.1.0->lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<4.0,>=2.1.0->lightning)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<4.0,>=2.1.0->lightning)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Co

In [None]:
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from transformers import BlipProcessor, BlipForConditionalGeneration
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import json
import os
from PIL import Image
from torchvision import transforms

In [None]:
DATA_FOLDER = "/content/drive/MyDrive/Gemini_Captions"
TRAIN_IMAGES_FOLDER = os.path.join(DATA_FOLDER, "train")
TRAIN_CAPTIONS_FILE = os.path.join(DATA_FOLDER, "train.json")
VAL_IMAGES_FOLDER = os.path.join(DATA_FOLDER, "val")
VAL_CAPTIONS_FILE = os.path.join(DATA_FOLDER, "val.json")

In [None]:
CHECKPOINT_PATH = "/content/drive/MyDrive/gemini_models/blip_checkpoints/"
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

In [None]:
# Dataset class
class CaptionDataset(Dataset):
    def __init__(self, processor, images_folder, captions_file):
        with open(captions_file, 'r') as f:
            self.data = json.load(f)
        print(f"Loaded {len(self.data)} image-caption pairs from {captions_file}")

        self.processor = processor
        self.images_folder = images_folder
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image_path = os.path.join(self.images_folder, item["filename"])
        image = Image.open(image_path).convert("RGB")

        pixel_values = self.transform(image)
        encoding = self.processor(
            text=item["description"],
            padding="max_length",
            return_tensors="pt"
        )

        return {
            "pixel_values": pixel_values,
            "input_ids": encoding.input_ids.squeeze(0),
            "attention_mask": encoding.attention_mask.squeeze(0),
            "labels": encoding.input_ids.squeeze(0)
        }


In [None]:
# Lightning Model with Freezing
class BlipLightning(pl.LightningModule):
    def __init__(self, model_name="Salesforce/blip-image-captioning-base", learning_rate=5e-5, freeze_vision=True, freeze_layers=6):
        super().__init__()
        self.model = BlipForConditionalGeneration.from_pretrained(model_name)
        self.processor = BlipProcessor.from_pretrained(model_name, use_fast=True)
        self.learning_rate = learning_rate

        if freeze_vision:
            for name, param in self.model.named_parameters():
                if "vision_model.embeddings" in name:
                    param.requires_grad = False
                if "vision_model.encoder.layers" in name:
                    parts = name.split(".")
                    try:
                        layer_index = int(parts[3])
                    except (IndexError, ValueError):
                        layer_index = None
                    if layer_index is not None and layer_index < freeze_layers:
                        param.requires_grad = False

    def forward(self, pixel_values, input_ids, attention_mask, labels):
        return self.model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

    def training_step(self, batch, batch_idx):
        outputs = self(
            pixel_values=batch["pixel_values"],
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"]
        )
        loss = outputs.loss
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self(
            pixel_values=batch["pixel_values"],
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"]
        )
        val_loss = outputs.loss
        self.log("val_loss", val_loss, prog_bar=True)
        print(f"Validation Loss [Batch {batch_idx}]: {val_loss.item()}")
        return val_loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)


In [None]:
# DataLoader function
def create_dataloader(images_folder, captions_file, batch_size=4, num_workers=2):
    processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base", use_fast=True)
    dataset = CaptionDataset(processor, images_folder, captions_file)

    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

In [None]:
# Training function with checkpoint handling
def train_model(num_epochs=3, batch_size=4):
    train_dataloader = create_dataloader(TRAIN_IMAGES_FOLDER, TRAIN_CAPTIONS_FILE, batch_size)
    val_dataloader = create_dataloader(VAL_IMAGES_FOLDER, VAL_CAPTIONS_FILE, batch_size)

    # Initialize model
    model = BlipLightning()

    # Check for existing checkpoints
    checkpoint_files = [f for f in os.listdir(CHECKPOINT_PATH) if f.endswith(".ckpt")]
    latest_checkpoint = None
    if checkpoint_files:
        latest_checkpoint = os.path.join(CHECKPOINT_PATH, sorted(checkpoint_files)[-1])

    if latest_checkpoint:
        print(f"Resuming training from checkpoint: {latest_checkpoint}")
        model = BlipLightning.load_from_checkpoint(latest_checkpoint)

    # Trainer with automatic checkpoint saving
    checkpoint_callback = ModelCheckpoint(
        dirpath=CHECKPOINT_PATH,
        filename="blip-{epoch:02d}-{val_loss:.4f}",
        save_top_k=3,
        monitor="val_loss",
        mode="min",
        save_last=True
    )

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        accelerator="auto",
        deterministic=True,
        log_every_n_steps=25,
        callbacks=[checkpoint_callback]
    )

    trainer.fit(model, train_dataloader, val_dataloader)

    # Copy latest checkpoint to Drive for backup
    last_checkpoint = checkpoint_callback.last_model_path
    if last_checkpoint:
        os.system(f"cp {last_checkpoint} {CHECKPOINT_PATH}")
        print(f"✅ Copied last checkpoint to Google Drive: {last_checkpoint}")

In [None]:
if __name__ == "__main__":
    train_model(num_epochs=10, batch_size=8)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/506 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Loaded 10224 image-caption pairs from /content/drive/MyDrive/Gemini_Captions/train.json
Loaded 1000 image-caption pairs from /content/drive/MyDrive/Gemini_Captions/val.json


config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

Resuming training from checkpoint: /content/drive/MyDrive/gemini_models/blip_checkpoints/blip-epoch=04-val_loss=0.1560.ckpt


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /content/drive/.shortcut-targets-by-id/1oouePirCzzAPSiJ4fUW1mm3M7c-_gDnm/gemini_models/blip_checkpoints exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                         | Params | Mode
--------------------------------------------------------------
0 | model | BlipForConditionalGeneration | 247 M  | eval
--------------------------------------------------------------
203 M     Trainable params
43.6 M    Non-trainable params
247 M     Total params
989.656   Total estimated model params size (MB)
0         M

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:476: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Validation Loss [Batch 0]: 0.18632736802101135
Validation Loss [Batch 1]: 0.15381239354610443


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation Loss [Batch 0]: 0.15508073568344116
Validation Loss [Batch 1]: 0.20424725115299225
Validation Loss [Batch 2]: 0.17950794100761414
Validation Loss [Batch 3]: 0.19617952406406403
Validation Loss [Batch 4]: 0.1817362755537033
Validation Loss [Batch 5]: 0.15741120278835297
Validation Loss [Batch 6]: 0.18574883043766022
Validation Loss [Batch 7]: 0.1585845798254013
Validation Loss [Batch 8]: 0.1610894501209259
Validation Loss [Batch 9]: 0.16459786891937256
Validation Loss [Batch 10]: 0.20075488090515137
Validation Loss [Batch 11]: 0.12196614593267441
Validation Loss [Batch 12]: 0.1445249766111374
Validation Loss [Batch 13]: 0.1532190889120102
Validation Loss [Batch 14]: 0.16380321979522705
Validation Loss [Batch 15]: 0.21657276153564453
Validation Loss [Batch 16]: 0.158913716673851
Validation Loss [Batch 17]: 0.1512705236673355
Validation Loss [Batch 18]: 0.16451102495193481
Validation Loss [Batch 19]: 0.16898499429225922
Validation Loss [Batch 20]: 0.18434041738510132
Validation

Validation: |          | 0/? [00:00<?, ?it/s]

Validation Loss [Batch 0]: 0.17607788741588593
Validation Loss [Batch 1]: 0.16737107932567596
Validation Loss [Batch 2]: 0.2734419107437134
Validation Loss [Batch 3]: 0.16339905560016632
Validation Loss [Batch 4]: 0.2330542355775833
Validation Loss [Batch 5]: 0.20637014508247375
Validation Loss [Batch 6]: 0.16022032499313354
Validation Loss [Batch 7]: 0.1566184163093567
Validation Loss [Batch 8]: 0.17629201710224152
Validation Loss [Batch 9]: 0.1644604504108429
Validation Loss [Batch 10]: 0.16324613988399506
Validation Loss [Batch 11]: 0.1601719856262207
Validation Loss [Batch 12]: 0.16667069494724274
Validation Loss [Batch 13]: 0.20431312918663025
Validation Loss [Batch 14]: 0.1843603402376175
Validation Loss [Batch 15]: 0.1823350042104721
Validation Loss [Batch 16]: 0.1696109175682068
Validation Loss [Batch 17]: 0.21083681285381317
Validation Loss [Batch 18]: 0.1630956381559372
Validation Loss [Batch 19]: 0.16180874407291412
Validation Loss [Batch 20]: 0.19434340298175812
Validation 

Validation: |          | 0/? [00:00<?, ?it/s]

Validation Loss [Batch 0]: 0.2546345889568329
Validation Loss [Batch 1]: 0.1770365983247757
Validation Loss [Batch 2]: 0.2254430055618286
Validation Loss [Batch 3]: 0.19683821499347687
Validation Loss [Batch 4]: 0.19430473446846008
Validation Loss [Batch 5]: 0.18241046369075775
Validation Loss [Batch 6]: 0.17930781841278076
Validation Loss [Batch 7]: 0.21162772178649902
Validation Loss [Batch 8]: 0.2542683482170105
Validation Loss [Batch 9]: 0.17176620662212372
Validation Loss [Batch 10]: 0.19212445616722107
Validation Loss [Batch 11]: 0.20357218384742737
Validation Loss [Batch 12]: 0.16663868725299835
Validation Loss [Batch 13]: 0.2140263468027115
Validation Loss [Batch 14]: 0.2543252110481262
Validation Loss [Batch 15]: 0.20623603463172913
Validation Loss [Batch 16]: 0.2173757553100586
Validation Loss [Batch 17]: 0.23003408312797546
Validation Loss [Batch 18]: 0.22449782490730286
Validation Loss [Batch 19]: 0.23975275456905365
Validation Loss [Batch 20]: 0.20311038196086884
Validatio

INFO:pytorch_lightning.utilities.rank_zero:
Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined