# Dependencies

In [None]:
# Install dependencies

!pip install -q transformers datasets torch scikit-learn wandb google-cloud-storage fsspec
!pip install -U datasets pyarrow fsspec

# Fine-tune BERT

In [3]:
# =================================
# 1. SETUP & AUTHENTICATION
# =================================
import os
import json
import logging
from pathlib import Path
from google.colab import files
from google.cloud import storage
from datasets import load_from_disk, DatasetDict

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Upload Google Cloud service account key
print("Upload your service account key JSON file...")
uploaded = files.upload()
KEY_FILE = list(uploaded.keys())[0]

# Set GCP credentials
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = KEY_FILE

# Define GCP parameters
PROJECT_ID = "gmail-sorter-457417"
BUCKET_NAME = "gmail-sorter-project"
GCS_TOKENIZED_DATA_PATH = f"gs://{BUCKET_NAME}/tokenized_data/"
GCS_CONFIG_FILE_PATH = f"gs://{BUCKET_NAME}/config/labels.json"
GCS_MODEL_OUTPUT_PATH = f"gs://{BUCKET_NAME}/model_artifacts/bert-finetuned-v1"

# =================================
# 2. LOAD DATA FROM GOOGLE CLOUD STORAGE
# =================================
def download_gcs_folder(bucket_name, source_prefix, local_dir):
    """Download all files in a GCS folder (prefix) to local storage."""
    client = storage.Client()
    bucket = client.bucket(bucket_name)
    blobs = bucket.list_blobs(prefix=source_prefix)

    os.makedirs(local_dir, exist_ok=True)
    for blob in blobs:
        if blob.name.endswith('/'):
            continue
        relative_path = blob.name[len(source_prefix):].lstrip('/')
        local_path = os.path.join(local_dir, relative_path)
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        logger.info(f"Downloading {blob.name} to {local_path}")
        blob.download_to_filename(local_path)

def download_gcs_file(bucket_name, source_path, local_path):
    """Download a single file from GCS to local."""
    client = storage.Client()
    bucket = client.bucket(bucket_name)
    blob = bucket.blob(source_path.replace(f"gs://{bucket_name}/", ""))
    os.makedirs(os.path.dirname(local_path), exist_ok=True)
    logger.info(f"Downloading {source_path} to {local_path}")
    blob.download_to_filename(local_path)
    with open(local_path, 'r') as f:
        return json.load(f)

# Download and load tokenized dataset
print("Downloading tokenized dataset…")
download_gcs_folder(BUCKET_NAME, "tokenized_data/", "./tokenized_data")

print("Loading dataset splits…")
base_path = Path("./tokenized_data").resolve()
tokenized_datasets = DatasetDict({
    "train": load_from_disk(f"file://{base_path}/train"),
    "validation": load_from_disk(f"file://{base_path}/validation"),
    "test": load_from_disk(f"file://{base_path}/test"),
})
print("Dataset splits loaded:", {k: v.num_rows for k, v in tokenized_datasets.items()})

# Debug: list local files
print("Files in ./tokenized_data:")
for root, _, files in os.walk("./tokenized_data"):
    for name in files:
        print(" ", os.path.join(root, name))

# Load label schema
print("Loading label schema...")
labels_json = download_gcs_file(BUCKET_NAME, GCS_CONFIG_FILE_PATH, "./labels.json")
all_labels = labels_json["labels"]
num_labels = len(all_labels)
print(f"Loaded {num_labels} labels: {all_labels}")

# === LIMIT DATASET FOR QUICK TESTING ===
# print("Subsetting datasets for quick test...")

# MAX_SAMPLES = 100  # adjust this as needed

# tokenized_datasets["train"] = tokenized_datasets["train"].select(range(min(MAX_SAMPLES, len(tokenized_datasets["train"]))))
# tokenized_datasets["validation"] = tokenized_datasets["validation"].select(range(min(20, len(tokenized_datasets["validation"]))))
# tokenized_datasets["test"] = tokenized_datasets["test"].select(range(min(20, len(tokenized_datasets["test"]))))

# print("Subset sizes:", {k: v.num_rows for k, v in tokenized_datasets.items()})

# =================================
# 3. MODEL TRAINING
# =================================
from transformers import (
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback
)
from sklearn.metrics import accuracy_score
import numpy as np
import torch
import wandb

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

wandb.login()
wandb.init(project="gmail-sorter")

# Metric for evaluation
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {"accuracy": accuracy_score(labels, preds)}

# Load pre-trained BERT model
checkpoint = "bert-base-uncased"
print(f"Initializing model from checkpoint: {checkpoint}")
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint, num_labels=num_labels
).to(device)

# Define training arguments
LOCAL_OUTPUT_DIR = "./results"
training_args = TrainingArguments(
    output_dir=LOCAL_OUTPUT_DIR,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    report_to="wandb",
)

# Create trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)

# Train the model
print("Starting training...")
trainer.train()
trainer.save_model(LOCAL_OUTPUT_DIR)
print(f"Model saved locally to: {LOCAL_OUTPUT_DIR}")

# =================================
# 4. UPLOAD MODEL TO GCS
# =================================
def upload_to_gcs(bucket_name, local_path, gcs_prefix):
    """Upload all files in a directory to GCS."""
    client = storage.Client()
    bucket = client.bucket(bucket_name)
    local_path = Path(local_path)

    for file in local_path.rglob("*"):
        if file.is_file():
            gcs_object = f"{gcs_prefix}/{file.relative_to(local_path).as_posix()}"
            logger.info(f"Uploading {file} to gs://{bucket_name}/{gcs_object}")
            bucket.blob(gcs_object).upload_from_filename(str(file))

print("Uploading model artifacts to GCS...")
upload_to_gcs(BUCKET_NAME, LOCAL_OUTPUT_DIR, "model_artifacts/bert-finetuned-v1")
print(f"Artifacts uploaded to: {GCS_MODEL_OUTPUT_PATH}")

# =================================
# 5. READY FOR DEPLOYMENT
# =================================
print("✅ Model is ready for deployment on Vertex AI.")
print(f"Model artifacts location: {GCS_MODEL_OUTPUT_PATH}")
print("Use this path in Vertex AI > Models to create a custom model deployment.")

Collecting fsspec
  Using cached fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)
Upload your service account key JSON file...


Saving gmail-sorter-457417-373253ca841a.json to gmail-sorter-457417-373253ca841a (2).json
Downloading tokenized dataset…




Loading dataset splits…
Dataset splits loaded: {'train': 5359, 'validation': 596, 'test': 662}
Files in ./tokenized_data:
  ./tokenized_data/dataset_dict.json
  ./tokenized_data/train/state.json
  ./tokenized_data/train/dataset_info.json
  ./tokenized_data/train/data-00000-of-00001.arrow
  ./tokenized_data/test/state.json
  ./tokenized_data/test/dataset_info.json
  ./tokenized_data/test/data-00000-of-00001.arrow
  ./tokenized_data/validation/state.json
  ./tokenized_data/validation/dataset_info.json
  ./tokenized_data/validation/data-00000-of-00001.arrow
Loading label schema...
Loaded 51 labels: ['A-Levels', 'Ameen', 'apt search', 'Ayra', 'Banking & Investments/Al Meezan', 'Banking & Investments/Leads', 'Banking & Investments/Meezan', 'Banking & Investments/RBC', 'Banking & Investments/Scotia', 'Banking & Investments/Support', 'Banking & Investments/TD', 'Banking & Investments/Walmart', 'Banking & Investments/Wealthsimple', 'Banking & Investments/Wise', 'Bookings/Communauto', 'Bookings

Initializing model from checkpoint: bert-base-uncased


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Starting training...


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,1.181556,0.766779
2,1.730200,0.610748,0.904362
3,0.540000,0.396403,0.922819
4,0.540000,0.306657,0.939597
5,0.218500,0.284562,0.942953
6,0.107300,0.24417,0.949664
7,0.107300,0.234403,0.954698
8,0.062100,0.228768,0.949664
9,0.045000,0.223871,0.954698


Model saved locally to: ./results
Uploading model artifacts to GCS...
Artifacts uploaded to: gs://gmail-sorter-project/model_artifacts/bert-finetuned-v1
✅ Model is ready for deployment on Vertex AI.
Model artifacts location: gs://gmail-sorter-project/model_artifacts/bert-finetuned-v1
Use this path in Vertex AI > Models to create a custom model deployment.


# Resume training

In [3]:
# =================================
# 1. SETUP & AUTHENTICATION
# =================================
import os
import json
import logging
from pathlib import Path
from google.colab import files
from google.cloud import storage
from datasets import load_from_disk, DatasetDict

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Upload Google Cloud service account key
print("Upload your service account key JSON file...")
uploaded = files.upload()
KEY_FILE = list(uploaded.keys())[0]

# Set GCP credentials
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = KEY_FILE

# Define GCP parameters
PROJECT_ID = "gmail-sorter-457417"
BUCKET_NAME = "gmail-sorter-project"
GCS_TOKENIZED_DATA_PATH = f"gs://{BUCKET_NAME}/tokenized_data/"
GCS_CONFIG_FILE_PATH = f"gs://{BUCKET_NAME}/config/labels.json"
GCS_MODEL_OUTPUT_PATH = f"gs://{BUCKET_NAME}/model_artifacts/bert-finetuned-v1"

# =================================
# 2. LOAD DATA FROM GOOGLE CLOUD STORAGE
# =================================
def download_gcs_folder(bucket_name, source_prefix, local_dir):
    """Download all files in a GCS folder (prefix) to local storage."""
    client = storage.Client()
    bucket = client.bucket(bucket_name)
    blobs = bucket.list_blobs(prefix=source_prefix)

    os.makedirs(local_dir, exist_ok=True)
    for blob in blobs:
        if blob.name.endswith('/'):
            continue
        relative_path = blob.name[len(source_prefix):].lstrip('/')
        local_path = os.path.join(local_dir, relative_path)
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        logger.info(f"Downloading {blob.name} to {local_path}")
        blob.download_to_filename(local_path)

def download_gcs_file(bucket_name, source_path, local_path):
    """Download a single file from GCS to local."""
    client = storage.Client()
    bucket = client.bucket(bucket_name)
    blob = bucket.blob(source_path.replace(f"gs://{bucket_name}/", ""))
    os.makedirs(os.path.dirname(local_path), exist_ok=True)
    logger.info(f"Downloading {source_path} to {local_path}")
    blob.download_to_filename(local_path)
    with open(local_path, 'r') as f:
        return json.load(f)

# Download and load tokenized dataset
print("Downloading tokenized dataset…")
download_gcs_folder(BUCKET_NAME, "tokenized_data/", "./tokenized_data")

print("Loading dataset splits…")
base_path = Path("./tokenized_data").resolve()
tokenized_datasets = DatasetDict({
    "train": load_from_disk(f"file://{base_path}/train"),
    "validation": load_from_disk(f"file://{base_path}/validation"),
    "test": load_from_disk(f"file://{base_path}/test"),
})
print("Dataset splits loaded:", {k: v.num_rows for k, v in tokenized_datasets.items()})

# Debug: list local files
print("Files in ./tokenized_data:")
for root, _, files in os.walk("./tokenized_data"):
    for name in files:
        print(" ", os.path.join(root, name))

# Load label schema
print("Loading label schema...")
labels_json = download_gcs_file(BUCKET_NAME, GCS_CONFIG_FILE_PATH, "./labels.json")
all_labels = labels_json["labels"]
num_labels = len(all_labels)
print(f"Loaded {num_labels} labels: {all_labels}")

# === LIMIT DATASET FOR QUICK TESTING ===
# print("Subsetting datasets for quick test...")

# MAX_SAMPLES = 100  # adjust this as needed

# tokenized_datasets["train"] = tokenized_datasets["train"].select(range(min(MAX_SAMPLES, len(tokenized_datasets["train"]))))
# tokenized_datasets["validation"] = tokenized_datasets["validation"].select(range(min(20, len(tokenized_datasets["validation"]))))
# tokenized_datasets["test"] = tokenized_datasets["test"].select(range(min(20, len(tokenized_datasets["test"]))))

# print("Subset sizes:", {k: v.num_rows for k, v in tokenized_datasets.items()})
# ------------------------------------------------------------
# 3. MODEL TRAINING  (strict resume‑only mode)
# ------------------------------------------------------------
from transformers import (
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback
)
from sklearn.metrics import accuracy_score
import numpy as np
import torch
import wandb
from google.cloud import storage
from pathlib import Path
import re
import sys

# ---------- 3.1  WandB ------------------------------------------------------
wandb.login()
wandb.init(
    project="gmail-sorter",
    id="knl65f4t",          # previous run ID
    resume="allow",
    reinit=True
)

# ---------- 3.2  Helper to fetch latest checkpoint --------------------------
LATEST_PREFIX = "model_artifacts/bert-finetuned-v1/"

def fetch_latest_checkpoint(bucket_name, gcs_prefix, local_root="./results"):
    """
    Return the *local* path to the most recent checkpoint.
    If none exists, return None.
    """
    client = storage.Client()
    bucket = client.bucket(bucket_name)

    pattern = re.compile(rf"^{re.escape(gcs_prefix)}checkpoint-(\d+)/")
    candidates = {}

    # ⚠️  No delimiter -> walk entire prefix tree
    for blob in bucket.list_blobs(prefix=gcs_prefix):
        m = pattern.match(blob.name)
        if m:
            step = int(m.group(1))
            candidates[step] = f"{gcs_prefix}checkpoint-{step}/"

    if not candidates:
        return None

    step = max(candidates)
    remote_prefix = candidates[step]
    print(f"🡇  Latest checkpoint in GCS: {remote_prefix}")

    local_dir = Path(local_root) / f"checkpoint-{step}"
    from shutil import rmtree
    if local_dir.exists():
        rmtree(local_dir)        # clean stale copy
    download_gcs_folder(bucket_name, remote_prefix, str(local_dir))
    return str(local_dir)

# ---------- 3.3  Enforce resume‑only ---------------------------------------
LOCAL_OUTPUT_DIR = "./results"
local_checkpoint = fetch_latest_checkpoint(
    BUCKET_NAME, LATEST_PREFIX, LOCAL_OUTPUT_DIR
)

if local_checkpoint is None:
    print("❌  No checkpoint found – aborting (resume‑only mode).")
    sys.exit(1)      # stop the cell here

print(f"✅  Resuming from checkpoint: {local_checkpoint}")

# ---------- 3.4  Build model -----------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForSequenceClassification.from_pretrained(
    local_checkpoint,
    num_labels=num_labels
).to(device)

# ---------- 3.5  Training arguments ----------------------------------------
training_args = TrainingArguments(
    output_dir=LOCAL_OUTPUT_DIR,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=None,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    logging_strategy="epoch",
    report_to="wandb",
)

# ---------- 3.6  Trainer ----------------------------------------------------
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    return {"accuracy": accuracy_score(labels, np.argmax(logits, axis=-1))}

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

# ---------- 3.7  Resume training -------------------------------------------
print("🚀  Continuing training …")
trainer.train(resume_from_checkpoint=True)

# ---------- 3.8  Save + upload as before -----------------------------------
trainer.save_model(LOCAL_OUTPUT_DIR)

# =================================
# 4. UPLOAD MODEL TO GCS
# =================================
def upload_to_gcs(bucket_name, local_path, gcs_prefix):
    """Upload all files in a directory to GCS."""
    client = storage.Client()
    bucket = client.bucket(bucket_name)
    local_path = Path(local_path)

    for file in local_path.rglob("*"):
        if file.is_file():
            gcs_object = f"{gcs_prefix}/{file.relative_to(local_path).as_posix()}"
            logger.info(f"Uploading {file} to gs://{bucket_name}/{gcs_object}")
            bucket.blob(gcs_object).upload_from_filename(str(file))

print("Uploading model artifacts to GCS...")
upload_to_gcs(BUCKET_NAME, LOCAL_OUTPUT_DIR, "model_artifacts/bert-finetuned-v1")
print(f"Artifacts uploaded to: {GCS_MODEL_OUTPUT_PATH}")

# =================================
# 5. READY FOR DEPLOYMENT
# =================================
print("✅ Model is ready for deployment on Vertex AI.")
print(f"Model artifacts location: {GCS_MODEL_OUTPUT_PATH}")
print("Use this path in Vertex AI > Models to create a custom model deployment.")

Collecting fsspec
  Using cached fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)
Upload your service account key JSON file...


Saving gmail-sorter-457417-373253ca841a.json to gmail-sorter-457417-373253ca841a (2).json
Downloading tokenized dataset…
Loading dataset splits…
Dataset splits loaded: {'train': 5359, 'validation': 596, 'test': 662}
Files in ./tokenized_data:
  ./tokenized_data/dataset_dict.json
  ./tokenized_data/train/state.json
  ./tokenized_data/train/dataset_info.json
  ./tokenized_data/train/data-00000-of-00001.arrow
  ./tokenized_data/test/state.json
  ./tokenized_data/test/dataset_info.json
  ./tokenized_data/test/data-00000-of-00001.arrow
  ./tokenized_data/validation/state.json
  ./tokenized_data/validation/dataset_info.json
  ./tokenized_data/validation/data-00000-of-00001.arrow
Loading label schema...




Loaded 51 labels: ['A-Levels', 'Ameen', 'apt search', 'Ayra', 'Banking & Investments/Al Meezan', 'Banking & Investments/Leads', 'Banking & Investments/Meezan', 'Banking & Investments/RBC', 'Banking & Investments/Scotia', 'Banking & Investments/Support', 'Banking & Investments/TD', 'Banking & Investments/Walmart', 'Banking & Investments/Wealthsimple', 'Banking & Investments/Wise', 'Bookings/Communauto', 'Bookings/Flights', 'Bookings/Hotel', 'Bookings/Transport', 'DOCS', 'Events & Tickets', 'Food & Shopping/Cancellations/Refunds', 'Food & Shopping/Delivery updates', 'Food & Shopping/loyalty & discounts', 'Food & Shopping/Receipts', 'Food & Shopping/Support', 'Government', 'Healthcare', 'Home/Fido', 'Home/Fizz', 'Home/Hydro', 'Home/Insurance', 'Home/Opus 6', 'Job Search/Applied', 'Job Search/Leads', 'Job Search/Next steps', 'Job Search/Rejected', 'Marium', 'McGill', 'Newsletters', 'Papa', 'Payments & Subscriptions/Action Required', 'Payments & Subscriptions/Cancellations & Refunds', 'Paym

0,1
eval/accuracy,0.9547
eval/loss,0.22387
eval/runtime,17.9962
eval/samples_per_second,33.118
eval/steps_per_second,2.112
total_flos,1.2695692340247552e+16
train/epoch,9.0
train/global_step,3015.0
train/grad_norm,0.49464
train/learning_rate,0.0


🡇  Latest checkpoint in GCS: model_artifacts/bert-finetuned-v1/checkpoint-3015/
✅  Resuming from checkpoint: results/checkpoint-3015
🚀  Continuing training …


Epoch,Training Loss,Validation Loss,Accuracy
10,0.0373,0.223527,0.95302


Could not locate the best model at ./results/checkpoint-2345/pytorch_model.bin, if you are running a distributed training on multiple nodes, you should activate `--save_on_each_node`.


Uploading model artifacts to GCS...
Artifacts uploaded to: gs://gmail-sorter-project/model_artifacts/bert-finetuned-v1
✅ Model is ready for deployment on Vertex AI.
Model artifacts location: gs://gmail-sorter-project/model_artifacts/bert-finetuned-v1
Use this path in Vertex AI > Models to create a custom model deployment.
