# Flickr8k Image Captioning Training (Colab Pro)

This notebook trains the `mini-transformer` image-captioning model on the Flickr8k dataset in Google Colab.

### Professional Workflow:
1. **Mount Drive:** Stores the trained model (`.pth`) and the dataset zips on Google Drive, so progress survives a runtime reset and the roughly 1 GB dataset is not re-downloaded every session.
2. **Clone Repo:** Pulls the latest code from GitHub into the Colab runtime.
3. **Hybrid Data Load:** Checks Google Drive for each dataset zip; any zip that is missing is downloaded once and a copy is saved to Drive.
4. **Install & Run:** Sets up the environment and starts the training script.
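
The hybrid data load in step 3 comes down to a per-file decision. A minimal sketch of that decision, using hypothetical names (`plan_sync`, `drive_dir`) that are illustrative and not part of the repository:

```python
from pathlib import Path

def plan_sync(zip_names, drive_dir):
    """Map each zip name to the action the hybrid load would take for it."""
    drive_dir = Path(drive_dir)
    plan = {}
    for name in zip_names:
        if (drive_dir / name).exists():
            # Cached on Drive: a plain copy to the local disk is enough
            plan[name] = "copy_from_drive"
        else:
            # Not cached yet: download once, then keep a copy on Drive
            plan[name] = "download_then_cache"
    return plan
```

Calling `plan_sync(['Flickr8k_Dataset.zip', 'Flickr8k_Text.zip'], '/content/drive/MyDrive/trainingData')` before syncing would show which files are about to be downloaded.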

In [1]:
import os
import shutil
import sys
from google.colab import drive
from pathlib import Path

# 1. Mount Google Drive
drive.mount('/content/drive')

# 2. Setup Workspace (Clean and Clone)
# If the repo is ever made private, this clone step will need to
# authenticate with a GitHub token.
os.chdir('/content')
if os.path.exists('ImageDescript'):
    print("Removing old folder...")
    shutil.rmtree('ImageDescript')

print("Cloning public repository...")
!git clone https://github.com/aclink88/ImageDescript

# 3. Enter the project and setup packages
%cd ImageDescript
# Use mkdir -p to ensure the full path exists
!mkdir -p src/data src/model src/train
# Create empty __init__.py files so 'src' and its subfolders can be imported as packages
!touch src/__init__.py src/data/__init__.py src/model/__init__.py src/train/__init__.py

# 4. Setup Paths for Data Loading
GDRIVE_DATA_DIR = Path('/content/drive/MyDrive/trainingData')
LOCAL_DATA_DIR = Path('/content/ImageDescript/data')

GDRIVE_DATA_DIR.mkdir(parents=True, exist_ok=True)
LOCAL_DATA_DIR.mkdir(parents=True, exist_ok=True)

print("\nSetup complete! The full repo is now available in this Colab runtime.")

Mounted at /content/drive
Cloning public repository...
Cloning into 'ImageDescript'...
remote: Enumerating objects: 61, done.
remote: Counting objects: 100% (61/61), done.
remote: Compressing objects: 100% (46/46), done.
remote: Total 61 (delta 24), reused 46 (delta 13), pack-reused 0 (from 0)
Receiving objects: 100% (61/61), 21.30 KiB | 21.30 MiB/s, done.
Resolving deltas: 100% (24/24), done.
/content/ImageDescript

Setup complete! The full repo is now available in this Colab runtime.
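
The comment in the clone step notes that a private repository would need a GitHub token. One common pattern is to embed the token in the HTTPS clone URL. The sketch below assumes that approach; `build_clone_url` and the `GITHUB_TOKEN` environment variable are hypothetical names, not something the repository defines:

```python
import os

def build_clone_url(owner, repo, token=None):
    """Return an HTTPS clone URL, embedding a token only when one is given."""
    if token:
        # Assumption: the token is passed as the username part of the URL
        return f"https://{token}@github.com/{owner}/{repo}.git"
    return f"https://github.com/{owner}/{repo}.git"

# Hypothetical usage: read the token from the environment, never hard-code it
url = build_clone_url("aclink88", "ImageDescript", os.environ.get("GITHUB_TOKEN"))
```

A URL built this way contains the secret, so it should not be printed in cell output, and it ends up in the clone's `.git/config`.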


In [2]:
## Sync Dataset from Drive or Source
def sync_data():
    zips = ['Flickr8k_Dataset.zip', 'Flickr8k_Text.zip']
    for zip_name in zips:
        gdrive_path = GDRIVE_DATA_DIR / zip_name
        local_zip_path = Path('/content') / zip_name
        if gdrive_path.exists():
            print(f"Found {zip_name} on Google Drive. Copying to local SSD...")
            !cp "{gdrive_path}" "{local_zip_path}"
        else:
            print(f"{zip_name} not found on Drive. Downloading from source...")
            url = f"https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/{zip_name}"
            !wget -O "{local_zip_path}" "{url}"
            print(f"Saving {zip_name} to Google Drive for future use...")
            !cp "{local_zip_path}" "{gdrive_path}"
        print(f"Extracting {zip_name} to local runtime disk...")
        !unzip -q -o "{local_zip_path}" -d "{LOCAL_DATA_DIR}"
        !rm "{local_zip_path}"

    # Handle the known typo in the original image zip folder name
    typo_dir = LOCAL_DATA_DIR / 'Flicker8k_Dataset'
    correct_dir = LOCAL_DATA_DIR / 'Flickr8k_Dataset'
    if typo_dir.exists() and not correct_dir.exists():
        typo_dir.rename(correct_dir)

sync_data()

Found Flickr8k_Dataset.zip on Google Drive. Copying to local SSD...
Extracting Flickr8k_Dataset.zip to local runtime disk...
Found Flickr8k_Text.zip on Google Drive. Copying to local SSD...
Extracting Flickr8k_Text.zip to local runtime disk...
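
Because the Drive copy is reused across sessions, a download or copy that was interrupted once could leave a truncated zip cached there. A minimal sketch of an integrity check that could run before extraction, using only the standard library (`zip_is_intact` is a hypothetical helper, not part of the notebook):

```python
import zipfile
from pathlib import Path

def zip_is_intact(zip_path):
    """Return True if the file is a readable zip whose members all pass their CRC check."""
    zip_path = Path(zip_path)
    if not zip_path.is_file() or not zipfile.is_zipfile(zip_path):
        return False
    try:
        with zipfile.ZipFile(zip_path) as archive:
            # testzip() returns the name of the first corrupt member, or None if all are fine
            return archive.testzip() is None
    except zipfile.BadZipFile:
        return False
```

If the check fails for a zip found on Drive, deleting that cached file and re-running the sync cell would trigger a fresh download. Note that `testzip()` reads every member, so it adds some time for the image archive.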


In [3]:
# 5. Install Dependencies
!pip install spacy tqdm pandas Pillow torch torchvision
!python -m spacy download en_core_web_sm

Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.8/12.8 MB 23.6 MB/s eta 0:00:00
✔ Download and installation successful
You can now load the package via spacy.load('en_core_web_sm')
⚠ Restart to reload dependencies
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.
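
The restart warning above concerns the current kernel. Training in this notebook is launched with `!python`, which starts a fresh process, so it should see newly installed packages without a restart. To check from the kernel whether a package is visible, a small sketch along these lines may help (`package_available` is a hypothetical helper):

```python
import importlib
import importlib.util

def package_available(name):
    """Return True if a top-level package can be located by the import system."""
    # Clear the finder caches so packages installed during this session are noticed
    importlib.invalidate_caches()
    return importlib.util.find_spec(name) is not None
```

For example, `package_available("en_core_web_sm")` reports whether the spaCy model is importable from the running kernel.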


In [4]:
# Add the project root to sys.path so this kernel can import the 'src' package
if '/content/ImageDescript' not in sys.path:
    sys.path.append('/content/ImageDescript')

# 6. Run training
# The '!' command runs in a separate process that does not see the sys.path edit above;
# PYTHONPATH=. puts the current directory (the repo root after %cd) on its import path.
!PYTHONPATH=. python -m src.train.train
## Keeping below for future debugging
# Import the train function directly from the script
# This avoids all the -m flag path headaches
# try:
#     from src.train.train import train
#     print("Successfully imported the train function.")
# except ModuleNotFoundError as e:
#     print(f"Error: {e}")
#     print("\nCurrent directory contents (recursive):")
#     !ls -R
#     raise

# # Call the train function
# train()

Training on device: cuda
Loading data...
Initializing model...
Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100% 97.8M/97.8M [00:00<00:00, 205MB/s]
Starting training...

--- Epoch 1/10 ---
100% 1265/1265 [05:02<00:00,  4.18it/s, loss=3.39]

--- Epoch 2/10 ---
100% 1265/1265 [05:01<00:00,  4.19it/s, loss=3.3]

--- Epoch 3/10 ---
100% 1265/1265 [05:01<00:00,  4.19it/s, loss=2.57]

--- Epoch 4/10 ---
100% 1265/1265 [05:02<00:00,  4.18it/s, loss=2.74]

--- Epoch 5/10 ---
100% 1265/1265 [05:01<00:00,  4.19it/s, loss=3.21]

--- Epoch 6/10 ---
100% 1265/1265 [05:01<00:00,  4.20it/s, loss=2.76]

--- Epoch 7/10 ---
100% 1265/1265 [05:02<00:00,  4.18it/s, loss=2.63]

--- Epoch 8/10 ---
100% 1265/1265 [05:02<00:00,  4.19it/s, loss=2.23]

--- Epoch 9/10 ---
100% 1265/1265 [05:01<00:00,  4.19it/s, loss=2.2]

--- Epoch 10/10 ---
100% 1265/1265 [05:01<00:00,  4.19it/s, loss=2.57]

Training complete. Saving model..
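
The training cell edits `sys.path` in the kernel and also sets `PYTHONPATH` for the `!python` command. These affect different processes: a child process does not inherit the kernel's `sys.path`, only its environment variables. A self-contained sketch that demonstrates the difference (`child_can_import` is a hypothetical helper written for this illustration):

```python
import os
import subprocess
import sys

def child_can_import(module_name, pythonpath=None):
    """Return True if a fresh Python process can import module_name."""
    env = dict(os.environ)
    env.pop("PYTHONPATH", None)  # start from a clean slate for the demonstration
    if pythonpath is not None:
        env["PYTHONPATH"] = str(pythonpath)
    result = subprocess.run(
        [sys.executable, "-c", f"import {module_name}"],
        env=env,
        capture_output=True,
    )
    # A zero exit code means the import succeeded in the child process
    return result.returncode == 0
```

Run from a directory other than the project root, `child_can_import("src")` would be expected to fail even after `sys.path.append('/content/ImageDescript')`, while `child_can_import("src", pythonpath="/content/ImageDescript")` should succeed.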

In [5]:
# 7. Save trained model back to Google Drive
if os.path.exists('captioning_model.pth'):
    !mkdir -p /content/drive/MyDrive/savedModels
    !cp captioning_model.pth /content/drive/MyDrive/savedModels/captioning_model.pth
    print("Model checkpoint successfully backed up to Google Drive at /savedModels/")

Model checkpoint successfully backed up to Google Drive at /savedModels/
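
The backup cell prints nothing when `captioning_model.pth` is missing and does not confirm that the copy landed on Drive. A minimal sketch of a stricter backup step, assuming a hypothetical helper `backup_checkpoint` that is not part of the repository:

```python
import shutil
from pathlib import Path

def backup_checkpoint(src, dest_dir):
    """Copy src into dest_dir and confirm the copy has the same size."""
    src = Path(src)
    dest_dir = Path(dest_dir)
    if not src.is_file():
        raise FileNotFoundError(f"No checkpoint found at {src}")
    dest_dir.mkdir(parents=True, exist_ok=True)
    dest = dest_dir / src.name
    shutil.copy2(src, dest)  # copy2 also preserves the modification time
    if dest.stat().st_size != src.stat().st_size:
        raise IOError(f"Size mismatch after copying to {dest}")
    return dest
```

In this notebook the call would look like `backup_checkpoint('captioning_model.pth', '/content/drive/MyDrive/savedModels')`. A size comparison is only a cheap sanity check; comparing file hashes would be stricter.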
