In [None]:
import pandas as pd
import numpy as np
# Widen pandas' text rendering so large frames are shown without truncation.
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
# Fully-qualified option name; the short form only works through pattern matching.
pd.set_option('display.max_colwidth', 200)

# `display` and `HTML` are part of the public `IPython.display` API; importing
# them from `IPython.core.display` is deprecated.
from IPython.display import display, HTML

# Stretch the classic-notebook page container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))


def df_stats(df):
    """Print a per-column summary of ``df`` and return its first rows.

    The shape is printed first, followed by a table with one row per column
    listing its name, number of null values, number of distinct values and
    dtype.

    Args:
        df: The ``pandas.DataFrame`` to summarise.

    Returns:
        ``df.head()``, so that a notebook cell ending with this call also
        shows a preview of the data.
    """
    print("\n***** Shape: ", df.shape," *****\n")

    df_stat_val = pd.DataFrame({
        'Name': df.columns.values.tolist(),
        'Null': df.isnull().sum().values.tolist(),
        'Unique': df.nunique().values.tolist(),
        'Dtypes': df.dtypes.tolist(),
    })

    # `tabulate` is an optional third-party pretty-printer: use it when it is
    # installed, otherwise fall back to pandas' own text rendering instead of
    # failing with ModuleNotFoundError.
    try:
        from tabulate import tabulate
    except ImportError:
        print(df_stat_val.to_string())
    else:
        print(tabulate(df_stat_val, headers='keys', tablefmt='psql'))
    return df.head()

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import wandb
# Commenting out the import of timm due to ModuleNotFoundError
# import timm
from transformers import AutoModel, AutoTokenizer
from tqdm import tqdm

import utils.data_processing.video
# Import the model classes and functions
from scripts.train_model import VideoEncoder, TextEncoder, VideoDataset, contrastive_loss, train_epoch

In [None]:
# Contrastive training of the video/text encoders on every visible GPU,
# writing one checkpoint per epoch to models/checkpoints.
import os
from pathlib import Path  # required by `checkpoint_dir` below

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler

# Setup multi-GPU training
world_size = torch.cuda.device_count()
print(f"Found {world_size} GPUs")

# Initialize process group
# NOTE(review): the notebook kernel is a single process that registers itself
# as rank 0 of `world_size` ranks. `init_process_group` waits for every rank to
# join, so with more than one GPU this call is expected to stall until it times
# out. Multi-GPU runs presumably belong in scripts/train_multi_gpu.py, which
# can spawn one process per rank -- TODO confirm.
if world_size > 1:
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", world_size=world_size, rank=0)

# Setup devices; fall back to the CPU so `devices[0]` is valid without a GPU.
devices = [torch.device(f'cuda:{i}') for i in range(world_size)] or [torch.device('cpu')]
print(f"Using devices: {devices}")

# Initialize models on the primary device
video_encoder = VideoEncoder().to(devices[0])
text_encoder = TextEncoder().to(devices[0])

# Wrap models with DDP if using multiple GPUs
if world_size > 1:
    video_encoder = DDP(video_encoder, device_ids=[0])
    text_encoder = DDP(text_encoder, device_ids=[0])

# Create dataset
train_dataset = VideoDataset(
    root="data/processed/reports",
    data_filename="reports_sampled_1000.csv",
    split="train",
    target_label="Report",
    datapoint_loc_label="FileName",
    mean=[107.56801, 107.56801, 107.56801],
    std=[40.988625, 40.988625, 40.988625],
)

# Scale the batch size with the number of GPUs, but never below one GPU's
# worth: a batch size of 0 is rejected by DataLoader.
batch_size = 24 * max(world_size, 1)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=6 * world_size,  # Scale workers with number of GPUs
    pin_memory=torch.cuda.is_available()  # pinned memory only helps host->GPU copies
)

# Training parameters
num_epochs = 2
learning_rate = 1e-4

# A single optimizer updates both encoders jointly
params = list(video_encoder.parameters()) + list(text_encoder.parameters())
optimizer = torch.optim.AdamW(params, lr=learning_rate)

# Create checkpoint directory
checkpoint_dir = Path('models/checkpoints')
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Training loop
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Note: train_epoch function needs to be called with rank=0 for notebook
    train_loss = train_epoch(video_encoder, text_encoder, train_dataloader,
                             optimizer, devices[0], rank=0)

    print(f"Training loss: {train_loss:.4f}")

    # Save checkpoint (the modulus is 1, i.e. after every epoch)
    if (epoch + 1) % 1 == 0:
        # Store the state dicts of the underlying modules so the checkpoint
        # keys carry no DDP "module." prefix.
        checkpoint = {
            'video_encoder': video_encoder.module.state_dict() if world_size > 1 else video_encoder.state_dict(),
            'text_encoder': text_encoder.module.state_dict() if world_size > 1 else text_encoder.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch
        }
        checkpoint_path = checkpoint_dir / f'checkpoint_epoch_{epoch+1}.pt'
        torch.save(checkpoint, checkpoint_path)
        print(f"Saved checkpoint for epoch {epoch+1} at {checkpoint_path}")

# Cleanup
if world_size > 1:
    dist.destroy_process_group()

In [4]:
def load_checkpoint(checkpoint_path, video_encoder, text_encoder, optimizer, world_size):
    """Restore both encoders and the optimizer from a saved checkpoint.

    The checkpoint is expected to be a dict with the keys ``'video_encoder'``,
    ``'text_encoder'``, ``'optimizer'`` and ``'epoch'``. The encoder entries
    are loaded into the underlying module, so the encoders may be passed
    either bare or wrapped in ``nn.DataParallel`` / ``DistributedDataParallel``.

    Args:
        checkpoint_path: Path of the file written by ``torch.save``.
        video_encoder: Video model to restore in place (bare or wrapped).
        text_encoder: Text model to restore in place (bare or wrapped).
        optimizer: Optimizer whose state is restored in place.
        world_size: Kept for backward compatibility; the wrapper is now
            detected from the model objects themselves.

    Returns:
        The epoch index stored in the checkpoint.
    """
    def _unwrap(model):
        # Parallel wrappers expose the real network as `.module`, and their own
        # state-dict keys carry a "module." prefix the checkpoint does not have.
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            return model.module
        return model

    # Deserialize onto the CPU so loading does not depend on the GPU layout of
    # the machine that wrote the file; load_state_dict copies the values onto
    # each parameter's current device.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    _unwrap(video_encoder).load_state_dict(checkpoint['video_encoder'])
    _unwrap(text_encoder).load_state_dict(checkpoint['text_encoder'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['epoch']

# Example usage: restore the checkpoint written after the final training epoch,
# provided that file is present on disk.
last_checkpoint = checkpoint_dir.joinpath(f'checkpoint_epoch_{num_epochs}.pt')
if last_checkpoint.exists():
    epoch = load_checkpoint(
        last_checkpoint,
        video_encoder,
        text_encoder,
        optimizer,
        world_size,
    )
    # The stored epoch index is zero-based; report it one-based.
    print(f"Loaded checkpoint from epoch {epoch+1}")

In [None]:
# Run scripts/train_multi_gpu.py as a separate Python process via IPython's `!` shell escape.
!python scripts/train_multi_gpu.py
