# SeaBot - Image

Train both the image model and the text generation model at the same time. Such that the embeddings of the video classification model are passed directly into the text generation model. Vanishing gradient could be a problem, I may have to do some skipped connections, make it a bit more shallow...

Open Question: Is the self-supervised fine-tuning process necessary?

Structure:

1. Train a generator to create instance of the distribution of annotation text. Unsupervised.
2. Train the video classification method on the distribution of the annotation imagery. Unsupervised.
3. Combine both methods into a singular pipeline and then use the actual annotations to derive results.

The Models:


The Data:
TODO: Remove redundant data entries "CPHD."

	Canadian and Local Annotations

	Video

In [2]:
!pip install transformers
!pip install torch
!pip install pytorchvideo
!pip install ffmpeg-python
!pip install torchvision
!pip install tqdm
!pip install fathomnet
!pip install openai
!pip install wandb
!pip install pafy

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.2.1[0m[39;49m -> [0m[32;49m23.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpi

## Pretraining

In [3]:
import boto3
import ffmpeg
import glob
import numpy as np
import os
import random
import torch
import traceback
from PIL import Image
from random import randint
from sklearn.model_selection import train_test_split
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import ViTForImageClassification
from tqdm import tqdm

import hashlib

# Constants and Configurations
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
WANDB_KEY = '856878a46a17646e66281426d43c4b77d3f9a00c'
BUCKET_NAME = 'seabot-d2-storage'
NUM_EPOCHS = 5
PATIENCE = 2
SAVE_FREQ = 100
LOCAL_MODEL_DIR = 'local_models'
LOCAL_VIDEO_DIR = 'local_videos'
LOCAL_IMAGE_DIR = 'local_images'
DATASET_ROOT_PATH = "SeaBot/Data/EX2304_Compressed"
IMAGE_ROOT_PATH = "SeaBot/Data/Extracted_Images"
MODEL_ROOT_PATH = "SeaBot/Models"
TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Set random seeds for reproducibility
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)

# Import and setup wandb
import wandb
wandb.login(key=WANDB_KEY)
wandb.init(project="seabot", name="drive_pretraining_AWS_large_EX2304")

# Helper Functions
def create_directory(dir_path):
    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

def check_s3_path(s3, bucket, path):
    bucket_obj = s3.Bucket(bucket)
    objects = list(bucket_obj.objects.filter(Prefix=path))
    if len(objects) > 0 and objects[0].key == path:
        return True
    else:
        return False

def download_from_s3(s3_client, bucket_name, s3_file_path, local_file_path):
    s3_client.download_file(bucket_name, s3_file_path, local_file_path)

def check_frame_exists_s3(s3_client, bucket_name, frame_s3_path):
    try:
        s3_client.head_object(Bucket=bucket_name, Key=frame_s3_path)
        return True
    except:
        return False

def generate_frame_hash(video_file_name, frame_number):
    # Create a unique hash for each frame using the video file name and frame number
    frame_id = f"{video_file_name}-{frame_number}"
    return hashlib.md5(frame_id.encode()).hexdigest()

def upload_frame_to_s3(local_path, s3_path, s3_client, bucket_name):
    try:
        s3_client.upload_file(local_path, bucket_name, s3_path)
    except Exception as e:
        print(f"Error occurred while uploading {local_path} to S3: {e}")

def extract_frames(bucket_name, video_s3_path, local_video_path, frame_rate=1):
    local_video_file = os.path.join(local_video_path, os.path.basename(video_s3_path))
    video_file_name = os.path.splitext(os.path.basename(local_video_file))[0]
    s3_client = boto3.client('s3')
    s3_client.download_file(bucket_name, video_s3_path, local_video_file)
    
    # Extract frames to LOCAL_IMAGE_DIR
    ffmpeg.input(local_video_file).filter('fps', fps=frame_rate).output(f"{LOCAL_IMAGE_DIR}/{video_file_name}_frame%03d.png").run()

    # Process each extracted frame
    frame_paths = glob.glob(f"{LOCAL_IMAGE_DIR}/{video_file_name}_frame*.png")
    for frame_path in frame_paths:
        frame_number = int(os.path.basename(frame_path)[len(video_file_name) + 6:-4])  # Extract frame number from filename
        frame_hash = generate_frame_hash(video_file_name, frame_number)
        s3_frame_path = f"{IMAGE_ROOT_PATH}/{frame_hash}.png"

        # Check if frame already exists in S3
        if not check_frame_exists_s3(s3_client, bucket_name, s3_frame_path):
            # Upload frame to S3 with error handling
            upload_frame_to_s3(frame_path, s3_frame_path, s3_client, bucket_name)
            # Optionally delete the frame from local storage after upload
        os.remove(frame_path)
            
def load_latest_checkpoint(model, optimizer, scheduler):
    try:
        checkpoints = glob.glob(os.path.join(LOCAL_MODEL_DIR, "*.pth"))
        checkpoints.sort(key=os.path.getmtime)
        if checkpoints:
            latest_checkpoint_path = checkpoints[-1]
            checkpoint = torch.load(latest_checkpoint_path)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            return model, optimizer, scheduler, checkpoint['epoch'], checkpoint['batch'], checkpoint['best_loss']
        else:
            return model, optimizer, scheduler, 0, 0, np.inf
    except Exception as e:
        print(f"Exception occurred: {e}")
        traceback.print_exc()
        raise e

def save_model_to_s3(local_model_path, s3_model_path, bucket_name):
    s3_client = boto3.client('s3')
    try:
        s3_client.upload_file(local_model_path, bucket_name, s3_model_path)
        print(f"Model successfully uploaded to {s3_model_path} in bucket {bucket_name}")
    except Exception as e:
        print(f"Error occurred while uploading model to S3: {e}")

def train_loop(start_epoch, start_batch, best_loss, model, optimizer, scheduler, train_loader, val_loader, criterion):
    global PATIENCE
    no_improve_epoch = 0
    best_val_loss = best_loss

    for epoch in range(start_epoch, NUM_EPOCHS):
        model.train()
        total_loss = 0
        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs.logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            # Wandb logging for training batch loss
            wandb.log({"epoch": epoch, "batch": batch_idx, "train_batch_loss": loss.item()})

            if (batch_idx + 1) % SAVE_FREQ == 0:
                checkpoint_path = os.path.join(LOCAL_MODEL_DIR, f'checkpoint_epoch_{epoch}_batch_{batch_idx}.pth')
                torch.save({
                    'epoch': epoch,
                    'batch': batch_idx,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_loss': best_val_loss,
                }, checkpoint_path)

        avg_train_loss = total_loss / len(train_loader)

        # Validation step
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for val_images, val_labels in val_loader:
                val_images, val_labels = val_images.to(DEVICE), val_labels.to(DEVICE)
                val_outputs = model(val_images)
                batch_loss = criterion(val_outputs.logits, val_labels)
                val_loss += batch_loss.item()

        avg_val_loss = val_loss / len(val_loader)

        # Wandb logging for average training and validation loss
        wandb.log({"epoch": epoch, "avg_train_loss": avg_train_loss, "avg_val_loss": avg_val_loss})

        # Check for improvement
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_path = os.path.join(LOCAL_MODEL_DIR, 'best_model.pth')
            torch.save(model.state_dict(), best_model_path)

            # Save to S3
            s3_model_path = os.path.join(MODEL_ROOT_PATH, 'best_model.pth')
            save_model_to_s3(best_model_path, s3_model_path, BUCKET_NAME)

            no_improve_epoch = 0
        else:
            no_improve_epoch += 1

        # Early stopping check
        if no_improve_epoch >= PATIENCE:
            print("Early stopping due to no improvement in validation loss.")
            break

        # Learning rate scheduler step
        scheduler.step()

        print(f'Epoch {epoch+1}: Training Loss: {avg_train_loss}, Validation Loss: {avg_val_loss}')


def list_s3_files(bucket_name, prefix):
    s3_client = boto3.client('s3')
    files = []
    paginator = s3_client.get_paginator('list_objects_v2')
    
    # Iterate through each page of results
    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
        if 'Contents' in page:
            for item in page['Contents']:
                if item['Key'].endswith('.mp4'):
                    files.append(item['Key'])

    return files

# Create necessary directories
create_directory(LOCAL_MODEL_DIR)
create_directory(LOCAL_VIDEO_DIR)
create_directory(LOCAL_IMAGE_DIR)

# Initialize S3 session
s3 = boto3.resource('s3')

# Check S3 paths
if check_s3_path(s3, BUCKET_NAME, DATASET_ROOT_PATH):
    print('The directory exists')
else:
    print('The directory does not exist')

# Define model, optimizer, scheduler, and criterion
model = ViTForImageClassification.from_pretrained('google/vit-large-patch16-224')
model.classifier = nn.Linear(model.config.hidden_size, 4)
for param in model.parameters():
    param.requires_grad = True
model = model.to(DEVICE)
optimizer = optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
criterion = nn.CrossEntropyLoss()

# Define the Dataset class
class ImageDataset(Dataset):
    def __init__(self, s3_client, bucket_name, image_keys, transform=None):
        self.s3_client = s3_client
        self.bucket_name = bucket_name
        self.image_keys = image_keys
        self.transform = transform

    def __len__(self):
        return len(self.image_keys)

    def __getitem__(self, idx):
        image_key = self.image_keys[idx]
        image_path = os.path.join(LOCAL_IMAGE_DIR, os.path.basename(image_key))
        self.s3_client.download_file(self.bucket_name, image_key, image_path)

        with Image.open(image_path) as img:
            rotation = randint(0, 3)
            rotated_image = img.rotate(rotation * 90)
            if self.transform:
                rotated_image = self.transform(rotated_image)

        # Delete the image from local storage after processing
        os.remove(image_path)

        return rotated_image, rotation

# Main script logic
if not os.path.isfile(LOCAL_MODEL_DIR):
    video_files = list_s3_files(BUCKET_NAME, DATASET_ROOT_PATH)
    for video_file in tqdm(video_files, desc='Extracting frames'):
        extract_frames(BUCKET_NAME, video_file, LOCAL_VIDEO_DIR)

    s3_client = boto3.client('s3')
    image_keys = list_s3_files(BUCKET_NAME, IMAGE_ROOT_PATH)  # Ensure this function returns the list of keys

    # Split the keys into training and validation sets
    train_keys, val_keys = train_test_split(image_keys, test_size=0.2, random_state=42)

    # Initialize the datasets
    train_dataset = ImageDataset(s3_client, BUCKET_NAME, train_keys, TRANSFORM)
    val_dataset = ImageDataset(s3_client, BUCKET_NAME, val_keys, TRANSFORM)

    # Initialize the DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8)

    print("Starting the training process.")
    try:
        model, optimizer, scheduler, start_epoch, start_batch, best_loss = load_latest_checkpoint(model, optimizer, scheduler)
        train_loop(start_epoch, start_batch, best_loss, model, optimizer, scheduler, train_loader, val_loader, criterion)
    except Exception as e:
        print(f"Error occurred during training: {e}. Exiting...")


  from .autonotebook import tqdm as notebook_tqdm
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpatrickallencooper[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
cat: /sys/module/amdgpu/initstate: No such file or directory
ERROR:root:Driver not initialized (amdgpu not found in modules)


The directory does not exist


Extracting frames:   0%|          | 0/1705 [00:00<?, ?it/s]ffmpeg version 4.4.2 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11.3.0 (conda-forge gcc 11.3.0-19)
  configuration: --prefix=/home/conda/feedstock_root/build_artifacts/ffmpeg_1671040255947/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_plac --cc=/home/conda/feedstock_root/build_artifacts/ffmpeg_1671040255947/_build_env/bin/x86_64-conda-linux-gnu-cc --cxx=/home/conda/feedstock_root/build_artifacts/ffmpeg_1671040255947/_build_env/bin/x86_64-conda-linux-gnu-c++ --nm=/home/conda/feedstock_root/build_artifacts/ffmpeg_1671040255947/_build_env/bin/x86_64-conda-linux-gnu-nm --ar=/home/conda/feedstock_root/build_artifacts/ffmpeg_1671040255947/_build_env/bin/x86_64-conda-linux-gnu-ar --disable-doc --disable-openssl --enable-avresample --enable-demuxer=dash --enable-hardcoded-table

Error: ffmpeg error (see stderr output for detail)

## Fathomnet Fine Tuning

In [None]:
# Define the custom dataset class for handling FathomNet data
class FathomNetDataset(Dataset):
    def __init__(self, fathomnet_root_path, concepts, transform=None):
        self.transform = transform
        self.images_info = []
        self.image_dir = fathomnet_root_path
        self.concepts = concepts
        self.concept_to_index = {concept: i for i, concept in enumerate(concepts)}

        print("Number of classes in set: " + str(len(concepts)))

        # Fetch image data for each concept and save the information
        for concept in concepts:
            try:
                images_info = images.find_by_concept(concept)
                self.images_info.extend(images_info)
            except ValueError as ve:
                print(f"Error fetching image data for concept {concept}: {ve}")
                continue

        # Sort images info to ensure consistent order across different runs
        self.images_info.sort(key=lambda x: x.uuid)

        # Create directory if it doesn't exist
        os.makedirs(self.image_dir, exist_ok=True)

        # Download images for each image info and save it to disk
        for image_info in tqdm(self.images_info, desc="Downloading images", unit="image"):
          image_url = image_info.url
          image_path = os.path.join(self.image_dir, f"{image_info.uuid}.jpg")

          # Download only if image doesn't already exist
          if not os.path.exists(image_path):
              try:
                  image_data = requests.get(image_url).content
                  with open(image_path, 'wb') as handler:
                      handler.write(image_data)
              except ValueError as ve:
                  print(f"Error downloading image from {image_url}: {ve}")
                  continue

    # Get the number of images in the dataset
    def __len__(self):
        return len(self.images_info)

    # Fetch an image and its label vector by index
    def __getitem__(self, idx):
      try:
          image_info = self.images_info[idx]
          image_path = os.path.join(self.image_dir, f"{image_info.uuid}.jpg")
          image = Image.open(image_path).convert('RGB')

          # Create label vector
          labels_vector = torch.zeros(len(self.concepts))
          for box in image_info.boundingBoxes:
            if box.concept in self.concept_to_index:
              labels_vector[self.concept_to_index[box.concept]] = 1

          # Apply transformations if any
          if self.transform:
            image = self.transform(image)

          return image, labels_vector
      except (IOError, OSError):
          print(f"Error reading image {image_path}. Skipping.")
          return None, None

def collate_fn(batch):
    # Filter out the (None, None) entries from the batch
    batch = [(image, label) for image, label in batch if image is not None and label is not None]

    # If there are no valid items left, return (None, None)
    if len(batch) == 0:
        return None, None

    # Extract and stack the images and labels
    images = torch.stack([item[0] for item in batch])
    labels_vector = torch.stack([item[1] for item in batch])

    return images, labels_vector

def load_and_train_model(model_root_path, old_model_path, fathomnet_root_path):
    # Define a transformation that resizes images to 224x224 pixels and then converts them to tensors
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Find the concepts for the bounding boxes
    concepts = boundingboxes.find_concepts()

    # Create a dataset with the given concepts and the defined transform
    dataset = FathomNetDataset(fathomnet_root_path, concepts, transform=transform)

    # Calculate the sizes for the training and validation datasets
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size

    # Set a seed for the random number generator
    torch.manual_seed(0)

    # Split the dataset into training and validation subsets
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    # Create data loaders for the training and validation datasets with batch size of 16
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=32, collate_fn=collate_fn)

    # Load the pre-trained Vision Transformer model and replace the classifier with a new one with 4 classes
    model = ViTForImageClassification.from_pretrained('google/vit-large-patch16-224')
    model.classifier = nn.Linear(model.config.hidden_size, 4)

    # Unfreeze all layers for training
    for param in model.parameters():
        param.requires_grad = True

    # Load the pre-trained model parameters for further training
    model.load_state_dict(torch.load(old_model_path))
    print("Loaded the d2 model parameters for further training")

    # Replace the classifier again, this time with the number of concept classes
    model.classifier = nn.Linear(model.config.hidden_size, len(concepts))

    # Move the model to the GPU if available
    model = model.to(device)

    # Define the optimizer as Adam
    optimizer = optim.Adam(model.parameters())

    # Define the number of training epochs and the patience for early stopping
    num_epochs = 1
    patience = 2
    no_improve_epoch = 0

    # Frequency for saving the model
    save_freq = 1000

    # Replace the StepLR with OneCycleLR
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, steps_per_epoch=len(train_loader), epochs=num_epochs)

    # Define a folder to store checkpoints
    checkpoint_folder = os.path.join(model_root_path, 'fn_checkpoints')

    # Make sure the checkpoint folder exists
    os.makedirs(checkpoint_folder, exist_ok=True)

    # Load the latest checkpoint if it exists
    checkpoints = sorted(glob.glob(os.path.join(checkpoint_folder, "*.pth")))

    # Define a function to load the latest checkpoint
    def load_latest_checkpoint():
        checkpoints = glob.glob(os.path.join(checkpoint_folder, "*.pth"))
        checkpoints.sort(key=lambda x: [int(num) for num in re.findall(r'\d+', x)], reverse=True) # Sorting based on epoch and batch number

        if checkpoints:
            latest_checkpoint_path = checkpoints[0]
            checkpoint = torch.load(latest_checkpoint_path)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            start_epoch = checkpoint['epoch']
            start_batch = checkpoint['batch']
            best_loss = checkpoint['best_loss']
            print(f"Loaded Checkpoint from {latest_checkpoint_path}!!")
            return start_epoch, start_batch, best_loss
        else:
            print("No Checkpoint found!!")
            return 0, 0, np.inf


    # Define the loss function as binary cross-entropy with logits
    criterion = nn.BCEWithLogitsLoss()

    # Ensure the model is in the correct device
    model.to(device)

    # Define a function for the training loop
    def train_loop(start_epoch, start_batch, best_loss):
        total_batches = len(train_loader)  # Total number of batches in one epoch
        for epoch in range(start_epoch, num_epochs):
            print(f'Starting epoch {epoch + 1}/{num_epochs}')
            running_loss = 0.0
            model.train()

            for batch_idx, (images, labels_vector) in enumerate(train_loader, start=start_batch):
                if images is None or labels_vector is None:
                    print("Terminating batch due to image or label vector read error.")
                    break

                images = images.to(device)
                labels_vector = labels_vector.to(device)

                optimizer.zero_grad()

                outputs = model(images)
                loss = criterion(outputs.logits, labels_vector)
                loss.backward()
                optimizer.step()

                running_loss += loss.item() * images.size(0)
                wandb.log({"fn_epoch": epoch, "fn_loss": loss.item()})

                # Print epoch progress
                progress = (batch_idx + 1) / total_batches * 100
                print(f"Epoch {epoch + 1} Progress: {progress:.2f}%")

                if (batch_idx + 1) % save_freq == 0:
                    checkpoint_path = os.path.join(checkpoint_folder, f'fn_checkpoint_{epoch + 1}_{batch_idx + 1}.pth')
                    torch.save({
                        'epoch': epoch,
                        'batch': batch_idx,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'loss': loss,
                        'best_loss': best_loss,
                    }, checkpoint_path)
                    print(f'Saved model checkpoint at {checkpoint_path}')

            epoch_loss = running_loss / len(train_loader.dataset)

            if epoch_loss < best_loss:
                best_loss = epoch_loss
                print(f'New best loss: {best_loss}')
                no_improve_epoch = 0  # Reset patience
            else:
                no_improve_epoch += 1

            if no_improve_epoch >= patience:
                print(f'Early stopping after {patience} epochs without improvement.')
                break

    # Load the latest checkpoint and start/resume training
    start_epoch, start_batch, best_loss = load_latest_checkpoint()
    train_loop(start_epoch, start_batch, best_loss)

final_model_path = os.path.join(model_root_path, 'fn_trained_model.pth')

if os.path.exists(final_model_path):
    print("Fully trained model already exists. Skipping training.")
else:
    old_model_path = os.path.join(model_root_path, 'd2_fine_tuned_model.pt')
    load_and_train_model(model_root_path, old_model_path, fathomnet_root_path)
    torch.save(model.state_dict(), final_model_path)
