In [None]:
# Mount Google Drive at /content/drive; a later cell copies the training data from
# /content/drive/MyDrive/data, so this has to run first.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
import os, textwrap

# Create the project folder and make it the working directory, so that the relative names used
# by the following cells (the generated .py files, "data", "output") all resolve inside it.
project_dir = "/content/image_captioning"
os.makedirs(project_dir, exist_ok=True)
os.chdir(project_dir)

# ===== Write all uploaded files to disk =====
files = {
    "data_utils.py": """from PIL import Image
import os
from torch.utils.data import Dataset


class ImageCaptionDataset(Dataset):
    '''Image/caption pairs read from a folder of images and a tab-separated captions file.

    Every useful line of the captions file holds a file name, a tab, and the caption text.
    Blank lines, lines without a tab, and lines whose image does not exist under images_dir
    are skipped, so len() counts only the usable pairs.
    '''

    def __init__(self, images_dir, captions_file, transform=None):
        self.images_dir = images_dir
        self.transform = transform
        self.samples = []
        with open(captions_file, 'r', encoding='utf-8') as handle:
            for raw_line in handle:
                # Text before the first tab is the file name; everything after it is the caption.
                fname, tab, caption = raw_line.strip().partition('\t')
                if not tab:
                    continue
                candidate = os.path.join(images_dir, fname)
                if not os.path.exists(candidate):
                    continue
                self.samples.append((candidate, caption))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Returns (image, caption); the image is an RGB PIL image unless a transform is set.
        path, text = self.samples[idx]
        picture = Image.open(path).convert('RGB')
        return (self.transform(picture) if self.transform else picture), text


# collate function for DataLoader
import torch


def collate_fn(batch, processor, tokenizer, max_target_len=64):
    '''Turn a list of (image, caption) pairs into model inputs.

    Args:
        batch: sequence of (image, caption) pairs as produced by the datasets in this file.
        processor: image processor; called with images=<list> and return_tensors='pt', and its
            result must expose pixel_values.
        tokenizer: caption tokenizer; its result must expose input_ids and, ideally,
            attention_mask.
        max_target_len: captions are padded/truncated to exactly this many tokens.

    Returns:
        dict with 'pixel_values' and 'labels'; padded label positions are set to -100 so the
        loss ignores them.
    '''
    images, captions = zip(*batch)
    # processor expects list of PIL images
    pixel_values = processor(images=list(images), return_tensors='pt').pixel_values
    # tokenize captions
    tokenized = tokenizer(list(captions), return_tensors='pt', padding='max_length',
                          truncation=True, max_length=max_target_len)
    input_ids = tokenized.input_ids

    # Decide what is padding from the attention mask when the tokenizer provides one. train.py
    # reuses the EOS token as the pad token, so comparing against pad_token_id would also hide
    # any genuine end-of-sequence token from the loss.
    attention_mask = getattr(tokenized, 'attention_mask', None)
    if attention_mask is not None:
        padding_positions = attention_mask == 0
    else:
        padding_positions = input_ids == tokenizer.pad_token_id

    # masked_fill returns a new tensor, leaving the tokenizer output untouched.
    labels = input_ids.masked_fill(padding_positions, -100)
    return {'pixel_values': pixel_values, 'labels': labels}


class HFImageCaptionDataset(Dataset):
    '''Expose the rows of a Hugging Face dataset as (image, caption) pairs.

    Args:
        hf_dataset: indexable collection of rows (dict-like), one per example.
        transform: optional callable applied to the decoded image.
        image_column: name of the row field holding the image.
        caption_column: first field tried when a row has no caption_0..caption_4 fields.
    '''

    def __init__(self, hf_dataset, transform=None, image_column='image', caption_column='caption'):
        self.hf_dataset = hf_dataset
        self.transform = transform
        self.image_column = image_column
        self.caption_column = caption_column

    def __len__(self):
        return len(self.hf_dataset)

    @staticmethod
    def _open_rgb(source):
        # Decode a path or file-like object and return an RGB copy; the handle is closed on exit.
        with Image.open(source) as opened:
            return opened.convert('RGB')

    def _read_image(self, item, idx):
        # Normalize the different image representations a row can carry into an RGB PIL image.
        from io import BytesIO

        try:
            image = item[self.image_column]
        except Exception:
            image = None
        if image is None:
            raise ValueError(f"Cannot read image for index {idx}; unexpected format: {type(image)}")

        if isinstance(image, Image.Image):
            # Same RGB guarantee as the branches below.
            return image if image.mode == 'RGB' else image.convert('RGB')
        if isinstance(image, dict):
            # Undecoded rows look like {'bytes': ..., 'path': ...}; either entry may be None.
            payload = image.get('bytes')
            if isinstance(payload, (bytes, bytearray)):
                return self._open_rgb(BytesIO(payload))
            path = image.get('path')
            if path:
                return self._open_rgb(path)
            raise ValueError(f"Cannot read image for index {idx}; unexpected format: {type(image)}")
        if isinstance(image, (bytes, bytearray)):
            return self._open_rgb(BytesIO(image))
        if isinstance(image, str) and os.path.exists(image):
            return self._open_rgb(image)
        # Anything else is handed on unchanged, as before.
        return image

    def _read_caption(self, item):
        # Join all captions found on the row into one space-separated string.
        caption_parts = []
        if isinstance(item, dict):
            # For Flickr8k, captions are caption_0 to caption_4
            for i in range(5):
                cap = item.get(f'caption_{i}')
                if cap and isinstance(cap, str):
                    caption_parts.append(cap.strip())
            # fallback to other fields: use the first one that is present and not None
            if not caption_parts:
                for field in (self.caption_column, 'caption', 'sentence', 'sentences', 'text'):
                    cap = item.get(field)
                    if cap is None:
                        continue
                    if isinstance(cap, str):
                        caption_parts.append(cap.strip())
                    elif isinstance(cap, (list, tuple)):
                        for entry in cap:
                            # some datasets keep each sentence as a dict with the text under 'raw'
                            if isinstance(entry, dict) and 'raw' in entry:
                                caption_parts.append(str(entry['raw']).strip())
                            else:
                                caption_parts.append(str(entry).strip())
                    break
        return ' '.join(caption_parts).replace('\\n', ' ').strip()

    def __getitem__(self, idx):
        # Returns (image, caption); raises ValueError when the row has no readable image.
        item = self.hf_dataset[int(idx)]
        pil_image = self._read_image(item, idx)
        img_out = self.transform(pil_image) if self.transform else pil_image
        return img_out, self._read_caption(item)

""",
    "train.py": """
import argparse
import os
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm.auto import tqdm
from data_utils import ImageCaptionDataset, collate_fn, HFImageCaptionDataset


def parse_args():
    '''Parse the training command line and return the resulting namespace.'''
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # One (flag, add_argument keyword arguments) entry per option, in the order shown by --help.
    options = [
        ('--data_dir', dict(type=str, default='/content/drive/MyDrive/data', help='directory with images/ and captions.txt')),
        ('--output_dir', dict(type=str, default='output')),
        ('--pretrained_model', dict(type=str, default='nlpconnect/vit-gpt2-image-captioning')),
        ('--epochs', dict(type=int, default=3)),
        ('--batch_size', dict(type=int, default=8)),
        ('--lr', dict(type=float, default=5e-5)),
        ('--max_target_len', dict(type=int, default=30)),
        ('--device', dict(type=str, default=default_device)),
        ('--use_hf_dataset', dict(action='store_true', help='Use Hugging Face dataset directly instead of local files')),
        ('--save_steps', dict(type=int, default=100, help='Save checkpoint every N steps')),
        ('--resume_from', dict(type=str, default=None, help='Path to checkpoint .pt file to resume from')),
        ('--auto_resume', dict(action='store_true', help='Automatically resume from the latest checkpoint in output_dir')),
    ]
    parser = argparse.ArgumentParser()
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    return parser.parse_args()


def save_checkpoint(model, optimizer, epoch, out_dir, name='checkpoint', step=None):
    '''Write a resumable training checkpoint to <out_dir>/<name>.pt.

    The file holds the model and optimizer state dicts, the epoch, and (when given) the step
    within that epoch.

    The checkpoint is first written to a temporary file next to the target and then moved into
    place, so an interrupted save never leaves a truncated <name>.pt behind for a later resume
    to load. While an existing checkpoint is being replaced, both copies are on disk briefly.

    Args:
        model: object exposing state_dict().
        optimizer: object exposing state_dict().
        epoch: epoch number stored under 'epoch'.
        out_dir: target directory; created when missing.
        name: file name without the .pt extension.
        step: optional step number stored under 'step' as an int.
    '''
    os.makedirs(out_dir, exist_ok=True)
    ckpt = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
    }
    if step is not None:
        ckpt['step'] = int(step)

    final_path = os.path.join(out_dir, f'{name}.pt')
    tmp_path = final_path + '.tmp'
    try:
        torch.save(ckpt, tmp_path)
        os.replace(tmp_path, final_path)
    finally:
        # Only still present when saving failed part-way; never leave the partial file around.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _latest_checkpoint(out_dir):
    '''Return the most recently written resumable .pt checkpoint in out_dir, or None.'''
    import glob

    candidates = []
    for pattern in ('checkpoint-step*.pt', 'checkpoint-epoch*.pt'):
        for path in glob.glob(os.path.join(out_dir, pattern)):
            # The per-epoch save_pretrained folders are directories, not resumable files.
            if os.path.isfile(path):
                candidates.append(path)
    if not candidates:
        return None
    # Step checkpoints reuse the same file name in every epoch, so the number in the name does
    # not tell which file is newest; the modification time does.
    return max(candidates, key=os.path.getmtime)


def main():
    '''Fine-tune the captioning model, optionally resuming from a .pt checkpoint.

    Writes a step checkpoint every --save_steps steps, a save_pretrained folder after each
    epoch (checkpoint-epoch<N>), and checkpoint-best for the lowest average epoch loss seen in
    this run.
    '''
    args = parse_args()
    device = args.device

    if args.auto_resume and not args.resume_from:
        latest = _latest_checkpoint(args.output_dir)
        if latest:
            args.resume_from = latest
            print(f'Auto-resuming from {latest}')

    # Load model + processor + tokenizer
    print('Loading model and processors...')
    model = VisionEncoderDecoderModel.from_pretrained(args.pretrained_model)
    processor = ViTImageProcessor.from_pretrained(args.pretrained_model)
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model)

    # GPT-2 doesn't have a pad token by default - set it to eos
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model.to(device)
    model.train()

    # dataset & dataloader
    if args.use_hf_dataset:
        from datasets import load_dataset
        hf_ds = load_dataset('jxie/flickr8k', split='train')
        dataset = HFImageCaptionDataset(hf_ds)
    else:
        images_dir = os.path.join(args.data_dir, 'images')
        captions_file = os.path.join(args.data_dir, 'captions.txt')
        dataset = ImageCaptionDataset(images_dir, captions_file)
    print(f'Dataset size: {len(dataset)}')
    if len(dataset) == 0:
        # Fail with a readable message instead of an obscure sampler or division error later.
        raise ValueError('No training samples found; check --data_dir (images/ and captions.txt) '
                         'or the Hugging Face dataset')

    def collate(batch):
        return collate_fn(batch, processor, tokenizer, max_target_len=args.max_target_len)

    # Shuffle unless a checkpoint is really being resumed. (--auto_resume without any checkpoint
    # on disk is a fresh run and keeps shuffling.) When resuming, the order is fixed so the
    # already-processed steps can be skipped by position.
    # NOTE: the skipped batches only match the ones trained before the interruption if that run
    # also iterated without shuffling.
    shuffle_dl = not args.resume_from
    if not shuffle_dl:
        print('Resuming: dataloader shuffle disabled to allow deterministic resume ordering')
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle_dl, collate_fn=collate)

    optimizer = AdamW(model.parameters(), lr=args.lr)

    resume_step = 0
    if args.resume_from:
        ckpt = torch.load(args.resume_from, map_location=device)
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        # if the checkpoint contains a step value, resume in the middle of that epoch
        if 'step' in ckpt:
            start_epoch = int(ckpt['epoch'])
            resume_step = int(ckpt['step'])
            print(f'Resuming from epoch {start_epoch}, step {resume_step}')
        else:
            start_epoch = int(ckpt['epoch']) + 1
            print(f'Resumed from end of epoch {ckpt["epoch"]}; starting epoch {start_epoch}')
    else:
        start_epoch = 1

    best_loss = float('inf')
    for epoch in range(start_epoch, args.epochs + 1):
        loop = tqdm(dataloader, desc=f'Epoch {epoch}')
        running_loss = 0.0
        trained_steps = 0
        # enumerate steps starting at 1 to match stored checkpoint 'step' semantics
        for step, batch in enumerate(loop, start=1):
            # If we're resuming inside this epoch, skip already-processed steps
            if epoch == start_epoch and resume_step and step <= resume_step:
                if step % 50 == 0:
                    # occasional progress print while skipping
                    print(f'Skipping previously processed step {step}/{resume_step}')
                continue
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            running_loss += loss.item()
            trained_steps += 1

            # save checkpoint every save_steps (step is 1-based); 0 or less disables step saves
            if args.save_steps > 0 and step % args.save_steps == 0:
                save_checkpoint(model, optimizer, epoch, args.output_dir, f'checkpoint-step{step}', step=step)

        # Average over the steps actually trained in this run, not over the skipped ones too.
        if trained_steps:
            avg_loss = running_loss / trained_steps
            print(f'Epoch {epoch} average loss: {avg_loss:.4f}')
        else:
            avg_loss = None
            print(f'Epoch {epoch}: every step was already processed before resuming; no loss to report')

        # save checkpoint every epoch
        ckpt_dir = os.path.join(args.output_dir, f'checkpoint-epoch{epoch}')
        model.save_pretrained(ckpt_dir)
        processor.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir)

        if avg_loss is not None and avg_loss < best_loss:
            best_loss = avg_loss
            best_dir = os.path.join(args.output_dir, 'checkpoint-best')
            model.save_pretrained(best_dir)
            processor.save_pretrained(best_dir)
            tokenizer.save_pretrained(best_dir)

    print('Training finished.')


if __name__ == '__main__':
    main()
""",
    "inference.py": """

import argparse
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from PIL import Image
import torch


def parse_args():
    '''Parse the inference command line; --model_dir and --image_path are mandatory.'''
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # One (flag, add_argument keyword arguments) entry per option.
    options = (
        ('--model_dir', dict(type=str, required=True)),
        ('--image_path', dict(type=str, required=True)),
        ('--max_length', dict(type=int, default=30)),
        ('--num_beams', dict(type=int, default=4)),
        ('--device', dict(type=str, default=default_device)),
    )
    parser = argparse.ArgumentParser()
    for flag, spec in options:
        parser.add_argument(flag, **spec)
    return parser.parse_args()


def generate_caption(model, processor, tokenizer, image_path, device, max_length=30, num_beams=4):
    '''Caption one image file with beam search and return the decoded text.

    Args:
        model: captioning model; moved to device and switched to eval mode here.
        processor: image processor, called with images=<image> and return_tensors='pt'.
        tokenizer: used to decode the first generated sequence, special tokens dropped.
        image_path: path of the image to caption.
        device: device for the model and the pixel values.
        max_length: forwarded to model.generate.
        num_beams: forwarded to model.generate.

    Returns:
        The caption with surrounding whitespace stripped.
    '''
    model.to(device)
    model.eval()

    rgb_image = Image.open(image_path).convert('RGB')
    encoded = processor(images=rgb_image, return_tensors='pt')
    pixel_values = encoded.pixel_values.to(device)

    generation_options = {'max_length': max_length, 'num_beams': num_beams}
    # Inference only: no gradients are needed.
    with torch.no_grad():
        output_ids = model.generate(pixel_values, **generation_options)

    first_sequence = output_ids[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True).strip()


def main():
    '''Load a saved checkpoint folder and print the caption generated for one image.'''
    args = parse_args()
    checkpoint = args.model_dir

    # Model, image processor and tokenizer are all stored in the same folder.
    model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
    processor = ViTImageProcessor.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    # ensure pad token exists
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    caption = generate_caption(
        model,
        processor,
        tokenizer,
        args.image_path,
        args.device,
        max_length=args.max_length,
        num_beams=args.num_beams,
    )
    print('Caption:', caption)


if __name__ == '__main__':
    main()
""",
    "prepare_dataset_from_hf.py": """
import argparse
import os
from datasets import load_dataset
from PIL import Image
from pathlib import Path


def save_image(item, out_path):
    '''Write the image carried by a dataset row to out_path as RGB.

    The image is looked up under 'image', then 'img', 'picture' and 'image_file'. Accepted
    representations: a PIL image, a dict with 'bytes' and/or 'path', raw bytes, a path string
    that exists, or anything Image.fromarray can convert.

    Raises:
        ValueError: when the row has no image field or the value cannot be converted.
    '''
    from io import BytesIO

    def _decode_and_save(source):
        # Open a path or file-like object, convert to RGB and save; the handle is closed on exit.
        with Image.open(source) as opened:
            opened.convert('RGB').save(out_path)

    img = None
    if isinstance(item, dict):
        img = item.get('image')
        if img is None:
            # some datasets use different fields
            for k in ('img', 'picture', 'image_file'):
                if k in item:
                    img = item[k]
                    break

    if img is None:
        raise ValueError('No image field found in item')

    if isinstance(img, Image.Image):
        # Convert like every other branch; RGBA or palette images cannot be written as JPEG.
        rgb = img if img.mode == 'RGB' else img.convert('RGB')
        rgb.save(out_path)
    elif isinstance(img, dict):
        # Undecoded rows look like {'bytes': ..., 'path': ...}; either entry may be None.
        payload = img.get('bytes')
        if isinstance(payload, (bytes, bytearray)):
            _decode_and_save(BytesIO(payload))
        elif img.get('path'):
            _decode_and_save(img['path'])
        else:
            raise ValueError(f'Unsupported image format: {type(img)}')
    elif isinstance(img, (bytes, bytearray)):
        _decode_and_save(BytesIO(img))
    elif isinstance(img, str) and os.path.exists(img):
        _decode_and_save(img)
    else:
        # Last resort: treat the value as an array-like image
        try:
            Image.fromarray(img).convert('RGB').save(out_path)
        except Exception as e:
            raise ValueError(f'Unsupported image format: {type(img)}') from e


def _extract_caption(item):
    '''Join every caption found on a dataset row into a single line of text.'''
    if not isinstance(item, dict):
        return ''
    caption_parts = []
    # For Flickr8k, captions are caption_0 to caption_4
    for n in range(5):
        cap = item.get(f'caption_{n}')
        if cap and isinstance(cap, str):
            caption_parts.append(cap.strip())
    # fallback to other fields: stop at the first one that yields text
    if not caption_parts:
        for field in ('caption', 'sentence', 'sentences', 'text'):
            cap = item.get(field)
            if cap is None:
                continue
            if isinstance(cap, str):
                caption_parts.append(cap.strip())
            elif isinstance(cap, (list, tuple)):
                for entry in cap:
                    # some flickr datasets keep captions under 'sentences' as list-of-dicts
                    if isinstance(entry, dict) and 'raw' in entry:
                        caption_parts.append(str(entry['raw']).strip())
                    else:
                        caption_parts.append(str(entry).strip())
            if caption_parts:
                break
    return ' '.join(caption_parts).replace('\\n', ' ').strip()


def main():
    '''Export a Hugging Face image-caption dataset to <output_dir>/images plus captions.txt.

    captions.txt gets one line per saved image: file name, a tab, then all captions of that
    image joined by spaces. Rows whose image cannot be saved are reported and skipped.
    '''
    p = argparse.ArgumentParser()
    p.add_argument('--output_dir', type=str, default='data')
    p.add_argument('--dataset', type=str, default='jxie/flickr8k')
    p.add_argument('--split', type=str, default='train', help='dataset split to use')
    p.add_argument('--limit', type=int, default=None, help='limit number of examples (for quick tests)')
    args = p.parse_args()

    out_dir = Path(args.output_dir)
    images_dir = out_dir / 'images'
    images_dir.mkdir(parents=True, exist_ok=True)
    captions_file = out_dir / 'captions.txt'

    print(f'Loading dataset {args.dataset} split={args.split}...')
    ds = load_dataset(args.dataset, split=args.split)

    total = len(ds)
    print(f'Dataset size: {total}')

    # Never report more than the dataset holds in the progress messages.
    limit = min(args.limit or total, total)

    with open(captions_file, 'w', encoding='utf-8') as fout:
        for index, item in enumerate(ds):
            if index >= limit:
                break

            # Build filename: use original file name if available, else index.jpg
            fname = None
            if isinstance(item, dict):
                # common fields
                for key in ('file_name', 'filename', 'image_id', 'img_id'):
                    value = item.get(key)
                    if value is not None and str(value):
                        fname = str(value)
                        break

            if not fname:
                # use index-based name
                fname = f'{index:08d}.jpg'
            elif not Path(fname).suffix:
                # Bare ids have no extension, and PIL picks the output format from it.
                fname += '.jpg'

            img_out = images_dir / fname

            # save image data
            try:
                save_image(item, img_out)
            except Exception as e:
                # fallback: try reading item['image']['path'] if present
                path = None
                if isinstance(item, dict) and 'image' in item and isinstance(item['image'], dict):
                    path = item['image'].get('path')
                if path and os.path.exists(path):
                    Image.open(path).convert('RGB').save(img_out)
                else:
                    print(f'Warning: could not save image for index {index}: {e}; skipping')
                    continue

            caption = _extract_caption(item)

            # The backslashes are doubled on purpose: this source is stored in a notebook string
            # and is unescaped once when the notebook writes the file to disk.
            fout.write(f"{fname}\\t{caption}\\n")

            if (index + 1) % 100 == 0:
                print(f'Saved {index + 1}/{limit}')

    print('Done. Saved images to', images_dir)
    print('Wrote captions to', captions_file)


if __name__ == '__main__':
    main()
""",
    "requirements.txt": """transformers>=4.30.0
datasets>=2.10.0
torch>=1.13.0
torchvision
Pillow
tqdm
nltk
accelerate
sentencepiece"""
}

# Write every embedded source file into the current working directory.
for name, content in files.items():
    # Explicit UTF-8 keeps the written files independent of the runtime's locale encoding.
    with open(name, "w", encoding="utf-8") as f:
        f.write(textwrap.dedent(content))

print("✅ Project files created:", os.listdir())

✅ Project files created: ['data_utils.py', 'requirements.txt', 'inference.py', 'prepare_dataset_from_hf.py', 'train.py']


In [None]:
!pip install -r requirements.txt



In [None]:
import shutil, os

# Source folder on the mounted Drive and its destination inside the working directory.
drive_data_path = "/content/drive/MyDrive/data"
local_data_path = "data"

# Mirror the whole data folder (images + captions.txt) locally; only warn when it is absent.
if not os.path.exists(drive_data_path):
    print("⚠️  data folder not found in Drive; check your path")
else:
    shutil.copytree(drive_data_path, local_data_path, dirs_exist_ok=True)


⚠️  data folder not found in Drive; check your path


In [None]:
# Fine-tune on the local data folder for 25 epochs with batches of 4. train.py writes a
# resumable .pt checkpoint every 1500 steps and a checkpoint-epoch<N> folder after each epoch.
!python train.py \
  --data_dir data \
  --output_dir output \
  --epochs 25 \
  --batch_size 4 \
  --save_steps 1500


2025-10-26 12:24:09.369000: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1761481449.389655   15860 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761481449.395931   15860 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1761481449.411656   15860 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1761481449.411685   15860 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1761481449.411690   15860 computation_placer.cc:177] computation placer alr

In [None]:
!zip -r captions_results.zip /content/image_captioning/output
from google.colab import files
files.download("captions_results.zip")


  adding: content/image_captioning/output/ (stored 0%)
  adding: content/image_captioning/output/checkpoint-epoch21/ (stored 0%)
  adding: content/image_captioning/output/checkpoint-epoch21/config.json (deflated 63%)
  adding: content/image_captioning/output/checkpoint-epoch21/model.safetensors (deflated 7%)
  adding: content/image_captioning/output/checkpoint-epoch21/merges.txt (deflated 53%)
  adding: content/image_captioning/output/checkpoint-epoch21/vocab.json (deflated 59%)
  adding: content/image_captioning/output/checkpoint-epoch21/special_tokens_map.json (deflated 81%)
  adding: content/image_captioning/output/checkpoint-epoch21/tokenizer_config.json (deflated 57%)
  adding: content/image_captioning/output/checkpoint-epoch21/tokenizer.json (deflated 82%)
  adding: content/image_captioning/output/checkpoint-epoch21/generation_config.json (deflated 40%)
  adding: content/image_captioning/output/checkpoint-epoch21/preprocessor_config.json (deflated 47%)
  adding: content/image_cap