# LLaVA Fine-tuning on Sydney Fish Dataset (A100 Optimized)

This notebook implements:
1. Stratified sampling across fish species
2. A100-optimized training parameters
3. Comprehensive testing and evaluation
4. Visualization of correct and incorrect predictions

In [None]:
import torch
from transformers import BitsAndBytesConfig, LlavaNextForConditionalGeneration, AutoProcessor
import lightning as L
from torch.utils.data import DataLoader
import re
from nltk import edit_distance
import numpy as np
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from huggingface_hub import notebook_login, HfApi
from Sydney_Fish_Dataset_Stratified import StratifiedSydneyFishDataset
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import json
from IPython.display import display
from PIL import Image

# Enable tensor cores for better performance
# 'high' permits reduced-precision internal formats (e.g. TF32) for float32
# matrix multiplications on GPUs that support them, per the PyTorch docs.
torch.set_float32_matmul_precision('high')

# LLaVA Fine-tuning on Sydney Fish Dataset (A100 Optimized)

This notebook implements:
1. Stratified sampling across fish species
2. A100-optimized training parameters
3. Comprehensive testing and evaluation
4. Visualization of correct and incorrect predictions

In [None]:
import torch
from transformers import BitsAndBytesConfig, LlavaNextForConditionalGeneration, AutoProcessor
import lightning as L
from torch.utils.data import DataLoader
import re
from nltk import edit_distance
import numpy as np
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from huggingface_hub import notebook_login, HfApi
from Sydney_Fish_Dataset_Stratified import StratifiedSydneyFishDataset
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import json
from IPython.display import display
from PIL import Image

# Enable tensor cores for better performance
# 'high' permits reduced-precision internal formats (e.g. TF32) for float32
# matrix multiplications on GPUs that support them, per the PyTorch docs.
# NOTE(review): this cell repeats the import/setup cell at the top of the notebook.
torch.set_float32_matmul_precision('high')

In [None]:
# Login to Hugging Face
# Opens the interactive token prompt. PushToHubCallback later uploads to REPO_ID,
# so the token presumably needs write access — confirm.
notebook_login()

In [None]:
# Check GPU and memory
# Prints a short description of CUDA device 0, or a notice when CUDA is unavailable.
if torch.cuda.is_available():
    print("PyTorch is connected to GPU.")
    print(f"GPU Device Name: {torch.cuda.get_device_name(0)}")
    print(f"Number of GPUs available: {torch.cuda.device_count()}")
    print(f"Current GPU: {torch.cuda.current_device()}")
    # The card's capacity is not the same as what is currently free, so report
    # both: mem_get_info() returns (free_bytes, total_bytes) for the device.
    free_bytes, total_bytes = torch.cuda.mem_get_info(0)
    print(f"Total GPU memory: {total_bytes / 1024**3:.2f} GB")
    print(f"Available GPU memory: {free_bytes / 1024**3:.2f} GB")
else:
    print("PyTorch is not connected to GPU.")

In [None]:
# Configuration
# MAX_LENGTH is passed as max_new_tokens to generate() in LlavaModelPLModule.test_step.
MAX_LENGTH = 256  # Can be larger on A100
# Checkpoint used to load both the processor and the base model.
MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"
# Hub repository that PushToHubCallback uploads to.
# NOTE(review): still contains the YOUR_HUGGINGFACE_USERNAME placeholder — replace before training.
REPO_ID = "YOUR_HUGGINGFACE_USERNAME/llava-v1.6-mistral-7b-sydneyfish-a100"

In [None]:
# Load processor and model
processor = AutoProcessor.from_pretrained(MODEL_ID)
# Pad on the right-hand side of each sequence when batching.
# NOTE(review): the same setting stays active for trainer.test(); batched generate()
# with a decoder-only LM is usually run with left padding — confirm.
processor.tokenizer.padding_side = "right"

# 4-bit NF4 weight quantization with double quantization; compute dtype is float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

# Load the quantized base model and let device_map="auto" decide placement.
# NOTE(review): the notebook targets an A100; bfloat16 compute (with a matching
# Trainer precision) may be preferable to float16 there — confirm.
model = LlavaNextForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    device_map="auto"
)

In [None]:
# Apply PEFT with A100-optimized settings
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

def find_all_linear_names(model):
    """Collect the leaf names of ``torch.nn.Linear`` layers to target with LoRA.

    Modules whose qualified name contains ``multi_modal_projector`` or
    ``vision_model`` are skipped, and ``lm_head`` is removed from the result.

    Args:
        model: A ``torch.nn.Module`` that is scanned with ``named_modules()``.

    Returns:
        list[str]: Unique leaf names (the last dotted component of each module
        name), sorted alphabetically so the result is identical across runs.
    """
    skipped_keywords = ('multi_modal_projector', 'vision_model')

    linear_names = {
        # The last dotted component is the attribute name PEFT matches against.
        name.rsplit('.', 1)[-1]
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
        and not any(keyword in name for keyword in skipped_keywords)
    }

    # The output head is deliberately left without an adapter.
    linear_names.discard('lm_head')

    # A set has no stable order; sort so target_modules is reproducible.
    return sorted(linear_names)

# LoRA adapter settings. target_modules is computed from the quantized base model
# before it is wrapped below.
lora_config = LoraConfig(
    r=16,  # Increased for A100
    lora_alpha=32,  # Increased for A100
    lora_dropout=0.1,
    target_modules=find_all_linear_names(model),
    init_lora_weights="gaussian",
)

# Prepare the k-bit model for training, then attach the LoRA adapters.
# NOTE(review): `model` is rebound in place, so re-running this cell without
# reloading the base model would wrap an already-wrapped model — confirm.
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)

In [None]:
# Build the three stratified splits with one shared seed, in train/validation/test order
train_dataset, val_dataset, test_dataset = (
    StratifiedSydneyFishDataset(split=split_name, seed=42)
    for split_name in ("train", "validation", "test")
)

# Size of each split
print("\nDataset Statistics:")
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Per-species sample counts of the training split, listed in sorted species order
print("\nTraining Set Species Distribution:")
species_counts = train_dataset.get_species_distribution()
for species in sorted(species_counts):
    print(f"{species}: {species_counts[species]} samples")

In [None]:
class LlavaModelPLModule(L.LightningModule):
    """Lightning module that fine-tunes and evaluates the wrapped LLaVA model.

    Training and validation log the loss returned by the model. Testing
    generates text for every batch, parses the species name out of the JSON in
    predictions and targets, and prints accuracy, a classification report, a
    confusion matrix and a few example images.

    Every step expects ``batch`` to unpack into
    ``(input_ids, attention_mask, pixel_values, image_sizes, labels)``.
    This class defines no ``*_dataloader`` hooks, so dataloaders have to be
    supplied to the ``Trainer`` calls.

    Args:
        config: Mapping of hyperparameters; ``"batch_size"`` and ``"lr"`` are read.
        processor: Processor used to decode token ids (``batch_decode``).
        model: Model exposing a forward pass that returns ``.loss`` and ``generate``.
    """

    # Label substituted when a JSON payload cannot be parsed into a species name.
    PARSE_ERROR = "ERROR"
    # How many example images are shown for each of the correct/incorrect groups.
    MAX_EXAMPLES = 5

    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.batch_size = config.get("batch_size")
        # Buffers filled by test_step and consumed by on_test_epoch_end.
        self.test_predictions = []
        self.test_targets = []
        self.test_images = []

    def _compute_loss(self, batch):
        """Run the supervised forward pass shared by training and validation."""
        input_ids, attention_mask, pixel_values, image_sizes, labels = batch
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            labels=labels
        )
        return outputs.loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch and log it as ``train_loss``."""
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch and log it as ``val_loss``."""
        loss = self._compute_loss(batch)
        self.log("val_loss", loss)
        return loss

    def on_test_epoch_start(self):
        """Start each test run with empty buffers so repeated runs do not mix results."""
        self.test_predictions = []
        self.test_targets = []
        self.test_images = []

    def _decodable_labels(self, labels):
        """Replace negative label ids with a real token id so they can be decoded.

        Negative ids (e.g. the -100 ignore index used by Hugging Face losses) are
        not valid tokens. The replacement is the pad token (or EOS when no pad
        token is set), which ``skip_special_tokens=True`` is expected to drop.
        """
        tokenizer = getattr(self.processor, "tokenizer", None)
        fill_id = getattr(tokenizer, "pad_token_id", None)
        if fill_id is None:
            fill_id = getattr(tokenizer, "eos_token_id", None)
        if fill_id is None:
            # Nothing safe to substitute; fall back to decoding the labels as given.
            return labels
        return labels.masked_fill(labels < 0, fill_id)

    def test_step(self, batch, batch_idx):
        """Generate predictions for one batch and buffer them with their targets.

        Returns:
            dict with the decoded ``"predictions"`` and ``"targets"`` of the batch.
        """
        input_ids, attention_mask, pixel_values, image_sizes, labels = batch

        # Store one (C, H, W) image per sample for visualization. 5-D pixel_values
        # carry several tiles per image; keeping only the first keeps the arrays
        # displayable and the buffer small.
        # NOTE(review): assumes tile 0 is the downscaled full image — confirm
        # against the image processor.
        preview = pixel_values[:, 0] if pixel_values.dim() == 5 else pixel_values
        self.test_images.extend(preview.detach().float().cpu().numpy())

        # Generate predictions
        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            max_new_tokens=MAX_LENGTH
        )

        # The generated sequence starts with the prompt; keep only the continuation.
        predictions = self.processor.batch_decode(
            generated_ids[:, input_ids.size(1):], skip_special_tokens=True
        )
        targets = self.processor.batch_decode(
            self._decodable_labels(labels), skip_special_tokens=True
        )

        self.test_predictions.extend(predictions)
        self.test_targets.extend(targets)

        return {"predictions": predictions, "targets": targets}

    @classmethod
    def _parse_species(cls, text):
        """Return ``["species"]["name"]`` from a JSON string, or ``PARSE_ERROR``.

        Only malformed JSON and missing/ill-typed fields are treated as parse
        failures; unrelated exceptions are allowed to propagate.
        """
        try:
            name = json.loads(text)["species"]["name"]
        except (ValueError, KeyError, TypeError, IndexError):
            return cls.PARSE_ERROR
        # Labels must be plain strings for the metrics below.
        return name if isinstance(name, str) else cls.PARSE_ERROR

    @staticmethod
    def _to_display_image(image):
        """Convert a stored channel-first array into an (H, W, C) float image in [0, 1]."""
        pixels = np.asarray(image, dtype=np.float32)
        if pixels.ndim == 4:
            # Defensive: a stack of tiles was stored; show the first one.
            pixels = pixels[0]
        pixels = np.transpose(pixels, (1, 2, 0))
        # Stored values are presumably normalized by the processor, so min-max
        # rescale them into the range imshow expects for float data.
        low, high = float(pixels.min()), float(pixels.max())
        if high > low:
            return (pixels - low) / (high - low)
        return np.zeros_like(pixels)

    def _show_examples(self, heading, indices, target_species, pred_species):
        """Print and plot up to ``MAX_EXAMPLES`` samples for the given indices."""
        print(heading)
        for idx in indices[:self.MAX_EXAMPLES]:
            print(f"\nTrue: {target_species[idx]}")
            print(f"Predicted: {pred_species[idx]}")
            plt.imshow(self._to_display_image(self.test_images[idx]))
            plt.axis('off')
            plt.show()

    def on_test_epoch_end(self):
        """Print metrics and plots for everything buffered during the test run."""
        if not self.test_predictions:
            # Avoid dividing by zero when the test loader produced no batches.
            print("\nNo test predictions were collected; skipping evaluation report.")
            return

        # Extract species names
        pred_species = [self._parse_species(text) for text in self.test_predictions]
        target_species = [self._parse_species(text) for text in self.test_targets]

        correct_indices = []
        incorrect_indices = []
        for i, (pred_name, target_name) in enumerate(zip(pred_species, target_species)):
            # An unparseable prediction is never correct, even if the target
            # also failed to parse.
            is_correct = pred_name != self.PARSE_ERROR and pred_name == target_name
            (correct_indices if is_correct else incorrect_indices).append(i)

        # Calculate and display metrics
        accuracy = len(correct_indices) / len(self.test_predictions)
        print(f"\nTest Accuracy: {accuracy * 100:.2f}%")

        print("\nClassification Report:")
        print(classification_report(target_species, pred_species))

        # Plot confusion matrix. The axes must list every label that occurs in
        # either targets or predictions, in the same order the matrix uses.
        class_labels = sorted(set(target_species) | set(pred_species))
        cm = confusion_matrix(target_species, pred_species, labels=class_labels)
        plt.figure(figsize=(15, 10))
        sns.heatmap(cm, annot=True, fmt='d',
                    xticklabels=class_labels,
                    yticklabels=class_labels)
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.xticks(rotation=45)
        plt.yticks(rotation=45)
        plt.tight_layout()
        plt.show()

        # Display example predictions
        self._show_examples("\nExample Correct Predictions:",
                            correct_indices, target_species, pred_species)
        self._show_examples("\nExample Incorrect Predictions:",
                            incorrect_indices, target_species, pred_species)

    def configure_optimizers(self):
        """Return an AdamW optimizer over the module parameters with ``config["lr"]``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.get("lr"))
        return optimizer

In [None]:
# A100-optimized training configuration
# max_epochs, check_val_every_n_epoch, gradient_clip_val and accumulate_grad_batches
# are forwarded to L.Trainer; lr is read by configure_optimizers; batch_size is
# read by LlavaModelPLModule.__init__.
config = {
    "max_epochs": 10,
    "check_val_every_n_epoch": 1,
    "gradient_clip_val": 1.0,
    "accumulate_grad_batches": 4,  # Reduced for A100
    "lr": 2e-4,
    "batch_size": 8,  # Increased for A100
    "num_nodes": 1,
    # NOTE(review): configure_optimizers builds AdamW without a scheduler, so
    # warmup_steps looks unused — confirm.
    "warmup_steps": 100,
    "result_path": "./result",
    "verbose": True,
    "num_workers": 8  # Increased for A100
}

# Wrap the PEFT model and processor in the Lightning module used by the trainer.
model_module = LlavaModelPLModule(config, processor, model)

In [None]:
# Callbacks
class PushToHubCallback(Callback):
    """Uploads training artifacts to the ``REPO_ID`` Hub repository.

    The model is pushed at the end of every training epoch; once training ends
    the processor is pushed, followed by the model.
    """

    def on_train_epoch_end(self, trainer, pl_module):
        """Push the current model, tagging the commit with the epoch number."""
        epoch = trainer.current_epoch
        print(f"Pushing model to the hub, epoch {epoch}")
        pl_module.model.push_to_hub(
            REPO_ID,
            commit_message=f"Training in progress, epoch {epoch}",
        )

    def on_train_end(self, trainer, pl_module):
        """Push the processor and then the model with a final commit message."""
        print("Pushing model to the hub after training")
        # Processor first, then model — same upload order as before.
        for artifact in (pl_module.processor, pl_module.model):
            artifact.push_to_hub(REPO_ID, commit_message="Training done")

# Stop training once the logged "val_loss" (see validation_step) has not
# decreased (mode="min") for 3 validation checks in a row.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=3,
    verbose=True,
    mode="min"
)

In [None]:
# Initialize trainer
# Epoch count, gradient accumulation, validation frequency and gradient clipping
# all come from `config`.
trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],  # single GPU: device index 0
    max_epochs=config.get("max_epochs"),
    accumulate_grad_batches=config.get("accumulate_grad_batches"),
    check_val_every_n_epoch=config.get("check_val_every_n_epoch"),
    gradient_clip_val=config.get("gradient_clip_val"),
    precision="16-mixed",  # float16 mixed precision, matching the float16 model dtype
    # Hub upload after every epoch, plus early stopping on val_loss.
    callbacks=[
        PushToHubCallback(),
        early_stop_callback
    ],
    strategy="auto",
    enable_progress_bar=True,
    enable_model_summary=True
)

In [None]:
# Train the model
# LlavaModelPLModule defines no *_dataloader hooks, so the loaders must be passed
# explicitly; fit() given only the module has no data to iterate over.
# NOTE(review): no collate_fn is visible in this notebook. This assumes each
# dataset item already collates into the (input_ids, attention_mask, pixel_values,
# image_sizes, labels) batch that training_step unpacks. If the dataset returns
# raw images/text, pass the matching collate_fn= to both loaders — TODO confirm.
train_loader = DataLoader(
    train_dataset,
    batch_size=config.get("batch_size"),
    shuffle=True,  # reshuffle the training samples every epoch
    num_workers=config.get("num_workers"),
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config.get("batch_size"),
    shuffle=False,
    num_workers=config.get("num_workers"),
)
trainer.fit(model_module, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Test the model and visualize results
# LlavaModelPLModule defines no test_dataloader hook, so the test loader must be
# passed explicitly.
# NOTE(review): no collate_fn is visible in this notebook. This assumes each
# dataset item already collates into the (input_ids, attention_mask, pixel_values,
# image_sizes, labels) batch that test_step unpacks, with input_ids holding only
# the prompt that generate() should continue — TODO confirm.
test_loader = DataLoader(
    test_dataset,
    batch_size=config.get("batch_size"),
    shuffle=False,  # keep a stable order so buffered images line up with predictions
    num_workers=config.get("num_workers"),
)
trainer.test(model_module, dataloaders=test_loader)

In [None]:
# Login to Hugging Face
# Opens the interactive token prompt. PushToHubCallback later uploads to REPO_ID,
# so the token presumably needs write access — confirm.
# NOTE(review): this cell repeats the login cell earlier in the notebook.
notebook_login()

In [None]:
# Check GPU and memory
# Prints a short description of CUDA device 0, or a notice when CUDA is unavailable.
if torch.cuda.is_available():
    print("PyTorch is connected to GPU.")
    print(f"GPU Device Name: {torch.cuda.get_device_name(0)}")
    print(f"Number of GPUs available: {torch.cuda.device_count()}")
    print(f"Current GPU: {torch.cuda.current_device()}")
    # The card's capacity is not the same as what is currently free, so report
    # both: mem_get_info() returns (free_bytes, total_bytes) for the device.
    free_bytes, total_bytes = torch.cuda.mem_get_info(0)
    print(f"Total GPU memory: {total_bytes / 1024**3:.2f} GB")
    print(f"Available GPU memory: {free_bytes / 1024**3:.2f} GB")
else:
    print("PyTorch is not connected to GPU.")

In [None]:
# Configuration
# MAX_LENGTH is passed as max_new_tokens to generate() in LlavaModelPLModule.test_step.
MAX_LENGTH = 256  # Can be larger on A100
# Checkpoint used to load both the processor and the base model.
MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"
# Hub repository that PushToHubCallback uploads to.
# NOTE(review): still contains the YOUR_HUGGINGFACE_USERNAME placeholder — replace before training.
REPO_ID = "YOUR_HUGGINGFACE_USERNAME/llava-v1.6-mistral-7b-sydneyfish-a100"

In [None]:
# Load processor and model
# NOTE(review): this cell repeats an earlier cell and rebinds `processor` and
# `model` to freshly loaded copies.
processor = AutoProcessor.from_pretrained(MODEL_ID)
# Pad on the right-hand side of each sequence when batching.
# NOTE(review): the same setting stays active for trainer.test(); batched generate()
# with a decoder-only LM is usually run with left padding — confirm.
processor.tokenizer.padding_side = "right"

# 4-bit NF4 weight quantization with double quantization; compute dtype is float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

# Load the quantized base model and let device_map="auto" decide placement.
# NOTE(review): the notebook targets an A100; bfloat16 compute (with a matching
# Trainer precision) may be preferable to float16 there — confirm.
model = LlavaNextForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    device_map="auto"
)

In [None]:
# Apply PEFT with A100-optimized settings
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

def find_all_linear_names(model):
    """Collect the leaf names of ``torch.nn.Linear`` layers to target with LoRA.

    Modules whose qualified name contains ``multi_modal_projector`` or
    ``vision_model`` are skipped, and ``lm_head`` is removed from the result.

    Args:
        model: A ``torch.nn.Module`` that is scanned with ``named_modules()``.

    Returns:
        list[str]: Unique leaf names (the last dotted component of each module
        name), sorted alphabetically so the result is identical across runs.
    """
    skipped_keywords = ('multi_modal_projector', 'vision_model')

    linear_names = {
        # The last dotted component is the attribute name PEFT matches against.
        name.rsplit('.', 1)[-1]
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
        and not any(keyword in name for keyword in skipped_keywords)
    }

    # The output head is deliberately left without an adapter.
    linear_names.discard('lm_head')

    # A set has no stable order; sort so target_modules is reproducible.
    return sorted(linear_names)

# LoRA adapter settings. target_modules is computed from the quantized base model
# before it is wrapped below.
lora_config = LoraConfig(
    r=16,  # Increased for A100
    lora_alpha=32,  # Increased for A100
    lora_dropout=0.1,
    target_modules=find_all_linear_names(model),
    init_lora_weights="gaussian",
)

# Prepare the k-bit model for training, then attach the LoRA adapters.
# NOTE(review): `model` is rebound in place, so re-running this cell without
# reloading the base model would wrap an already-wrapped model — confirm.
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)

In [None]:
# Build the three stratified splits with one shared seed, in train/validation/test order
train_dataset, val_dataset, test_dataset = (
    StratifiedSydneyFishDataset(split=split_name, seed=42)
    for split_name in ("train", "validation", "test")
)

# Size of each split
print("\nDataset Statistics:")
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Per-species sample counts of the training split, listed in sorted species order
print("\nTraining Set Species Distribution:")
species_counts = train_dataset.get_species_distribution()
for species in sorted(species_counts):
    print(f"{species}: {species_counts[species]} samples")

In [None]:
class LlavaModelPLModule(L.LightningModule):
    """Lightning module that fine-tunes and evaluates the wrapped LLaVA model.

    Training and validation log the loss returned by the model. Testing
    generates text for every batch, parses the species name out of the JSON in
    predictions and targets, and prints accuracy, a classification report, a
    confusion matrix and a few example images.

    Every step expects ``batch`` to unpack into
    ``(input_ids, attention_mask, pixel_values, image_sizes, labels)``.
    This class defines no ``*_dataloader`` hooks, so dataloaders have to be
    supplied to the ``Trainer`` calls.

    Args:
        config: Mapping of hyperparameters; ``"batch_size"`` and ``"lr"`` are read.
        processor: Processor used to decode token ids (``batch_decode``).
        model: Model exposing a forward pass that returns ``.loss`` and ``generate``.
    """

    # Label substituted when a JSON payload cannot be parsed into a species name.
    PARSE_ERROR = "ERROR"
    # How many example images are shown for each of the correct/incorrect groups.
    MAX_EXAMPLES = 5

    def __init__(self, config, processor, model):
        super().__init__()
        self.config = config
        self.processor = processor
        self.model = model
        self.batch_size = config.get("batch_size")
        # Buffers filled by test_step and consumed by on_test_epoch_end.
        self.test_predictions = []
        self.test_targets = []
        self.test_images = []

    def _compute_loss(self, batch):
        """Run the supervised forward pass shared by training and validation."""
        input_ids, attention_mask, pixel_values, image_sizes, labels = batch
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            labels=labels
        )
        return outputs.loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch and log it as ``train_loss``."""
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch and log it as ``val_loss``."""
        loss = self._compute_loss(batch)
        self.log("val_loss", loss)
        return loss

    def on_test_epoch_start(self):
        """Start each test run with empty buffers so repeated runs do not mix results."""
        self.test_predictions = []
        self.test_targets = []
        self.test_images = []

    def _decodable_labels(self, labels):
        """Replace negative label ids with a real token id so they can be decoded.

        Negative ids (e.g. the -100 ignore index used by Hugging Face losses) are
        not valid tokens. The replacement is the pad token (or EOS when no pad
        token is set), which ``skip_special_tokens=True`` is expected to drop.
        """
        tokenizer = getattr(self.processor, "tokenizer", None)
        fill_id = getattr(tokenizer, "pad_token_id", None)
        if fill_id is None:
            fill_id = getattr(tokenizer, "eos_token_id", None)
        if fill_id is None:
            # Nothing safe to substitute; fall back to decoding the labels as given.
            return labels
        return labels.masked_fill(labels < 0, fill_id)

    def test_step(self, batch, batch_idx):
        """Generate predictions for one batch and buffer them with their targets.

        Returns:
            dict with the decoded ``"predictions"`` and ``"targets"`` of the batch.
        """
        input_ids, attention_mask, pixel_values, image_sizes, labels = batch

        # Store one (C, H, W) image per sample for visualization. 5-D pixel_values
        # carry several tiles per image; keeping only the first keeps the arrays
        # displayable and the buffer small.
        # NOTE(review): assumes tile 0 is the downscaled full image — confirm
        # against the image processor.
        preview = pixel_values[:, 0] if pixel_values.dim() == 5 else pixel_values
        self.test_images.extend(preview.detach().float().cpu().numpy())

        # Generate predictions
        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            image_sizes=image_sizes,
            max_new_tokens=MAX_LENGTH
        )

        # The generated sequence starts with the prompt; keep only the continuation.
        predictions = self.processor.batch_decode(
            generated_ids[:, input_ids.size(1):], skip_special_tokens=True
        )
        targets = self.processor.batch_decode(
            self._decodable_labels(labels), skip_special_tokens=True
        )

        self.test_predictions.extend(predictions)
        self.test_targets.extend(targets)

        return {"predictions": predictions, "targets": targets}

    @classmethod
    def _parse_species(cls, text):
        """Return ``["species"]["name"]`` from a JSON string, or ``PARSE_ERROR``.

        Only malformed JSON and missing/ill-typed fields are treated as parse
        failures; unrelated exceptions are allowed to propagate.
        """
        try:
            name = json.loads(text)["species"]["name"]
        except (ValueError, KeyError, TypeError, IndexError):
            return cls.PARSE_ERROR
        # Labels must be plain strings for the metrics below.
        return name if isinstance(name, str) else cls.PARSE_ERROR

    @staticmethod
    def _to_display_image(image):
        """Convert a stored channel-first array into an (H, W, C) float image in [0, 1]."""
        pixels = np.asarray(image, dtype=np.float32)
        if pixels.ndim == 4:
            # Defensive: a stack of tiles was stored; show the first one.
            pixels = pixels[0]
        pixels = np.transpose(pixels, (1, 2, 0))
        # Stored values are presumably normalized by the processor, so min-max
        # rescale them into the range imshow expects for float data.
        low, high = float(pixels.min()), float(pixels.max())
        if high > low:
            return (pixels - low) / (high - low)
        return np.zeros_like(pixels)

    def _show_examples(self, heading, indices, target_species, pred_species):
        """Print and plot up to ``MAX_EXAMPLES`` samples for the given indices."""
        print(heading)
        for idx in indices[:self.MAX_EXAMPLES]:
            print(f"\nTrue: {target_species[idx]}")
            print(f"Predicted: {pred_species[idx]}")
            plt.imshow(self._to_display_image(self.test_images[idx]))
            plt.axis('off')
            plt.show()

    def on_test_epoch_end(self):
        """Print metrics and plots for everything buffered during the test run."""
        if not self.test_predictions:
            # Avoid dividing by zero when the test loader produced no batches.
            print("\nNo test predictions were collected; skipping evaluation report.")
            return

        # Extract species names
        pred_species = [self._parse_species(text) for text in self.test_predictions]
        target_species = [self._parse_species(text) for text in self.test_targets]

        correct_indices = []
        incorrect_indices = []
        for i, (pred_name, target_name) in enumerate(zip(pred_species, target_species)):
            # An unparseable prediction is never correct, even if the target
            # also failed to parse.
            is_correct = pred_name != self.PARSE_ERROR and pred_name == target_name
            (correct_indices if is_correct else incorrect_indices).append(i)

        # Calculate and display metrics
        accuracy = len(correct_indices) / len(self.test_predictions)
        print(f"\nTest Accuracy: {accuracy * 100:.2f}%")

        print("\nClassification Report:")
        print(classification_report(target_species, pred_species))

        # Plot confusion matrix. The axes must list every label that occurs in
        # either targets or predictions, in the same order the matrix uses.
        class_labels = sorted(set(target_species) | set(pred_species))
        cm = confusion_matrix(target_species, pred_species, labels=class_labels)
        plt.figure(figsize=(15, 10))
        sns.heatmap(cm, annot=True, fmt='d',
                    xticklabels=class_labels,
                    yticklabels=class_labels)
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.xticks(rotation=45)
        plt.yticks(rotation=45)
        plt.tight_layout()
        plt.show()

        # Display example predictions
        self._show_examples("\nExample Correct Predictions:",
                            correct_indices, target_species, pred_species)
        self._show_examples("\nExample Incorrect Predictions:",
                            incorrect_indices, target_species, pred_species)

    def configure_optimizers(self):
        """Return an AdamW optimizer over the module parameters with ``config["lr"]``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.get("lr"))
        return optimizer

In [None]:
# A100-optimized training configuration
# max_epochs, check_val_every_n_epoch, gradient_clip_val and accumulate_grad_batches
# are forwarded to L.Trainer; lr is read by configure_optimizers; batch_size is
# read by LlavaModelPLModule.__init__.
config = {
    "max_epochs": 10,
    "check_val_every_n_epoch": 1,
    "gradient_clip_val": 1.0,
    "accumulate_grad_batches": 4,  # Reduced for A100
    "lr": 2e-4,
    "batch_size": 8,  # Increased for A100
    "num_nodes": 1,
    # NOTE(review): configure_optimizers builds AdamW without a scheduler, so
    # warmup_steps looks unused — confirm.
    "warmup_steps": 100,
    "result_path": "./result",
    "verbose": True,
    "num_workers": 8  # Increased for A100
}

# Wrap the PEFT model and processor in the Lightning module used by the trainer.
model_module = LlavaModelPLModule(config, processor, model)

In [None]:
# Callbacks
class PushToHubCallback(Callback):
    """Uploads training artifacts to the ``REPO_ID`` Hub repository.

    The model is pushed at the end of every training epoch; once training ends
    the processor is pushed, followed by the model.
    """

    def on_train_epoch_end(self, trainer, pl_module):
        """Push the current model, tagging the commit with the epoch number."""
        epoch = trainer.current_epoch
        print(f"Pushing model to the hub, epoch {epoch}")
        pl_module.model.push_to_hub(
            REPO_ID,
            commit_message=f"Training in progress, epoch {epoch}",
        )

    def on_train_end(self, trainer, pl_module):
        """Push the processor and then the model with a final commit message."""
        print("Pushing model to the hub after training")
        # Processor first, then model — same upload order as before.
        for artifact in (pl_module.processor, pl_module.model):
            artifact.push_to_hub(REPO_ID, commit_message="Training done")

# Stop training once the logged "val_loss" (see validation_step) has not
# decreased (mode="min") for 3 validation checks in a row.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=3,
    verbose=True,
    mode="min"
)

In [None]:
# Initialize trainer
# Epoch count, gradient accumulation, validation frequency and gradient clipping
# all come from `config`.
trainer = L.Trainer(
    accelerator="gpu",
    devices=[0],  # single GPU: device index 0
    max_epochs=config.get("max_epochs"),
    accumulate_grad_batches=config.get("accumulate_grad_batches"),
    check_val_every_n_epoch=config.get("check_val_every_n_epoch"),
    gradient_clip_val=config.get("gradient_clip_val"),
    precision="16-mixed",  # float16 mixed precision, matching the float16 model dtype
    # Hub upload after every epoch, plus early stopping on val_loss.
    callbacks=[
        PushToHubCallback(),
        early_stop_callback
    ],
    strategy="auto",
    enable_progress_bar=True,
    enable_model_summary=True
)

In [None]:
# Train the model
# NOTE(review): this cell repeats the training cell earlier in the notebook, so a
# full top-to-bottom run trains twice — confirm that is intended.
# LlavaModelPLModule defines no *_dataloader hooks, so the loaders must be passed
# explicitly; fit() given only the module has no data to iterate over.
# NOTE(review): no collate_fn is visible in this notebook. This assumes each
# dataset item already collates into the (input_ids, attention_mask, pixel_values,
# image_sizes, labels) batch that training_step unpacks. If the dataset returns
# raw images/text, pass the matching collate_fn= to both loaders — TODO confirm.
train_loader = DataLoader(
    train_dataset,
    batch_size=config.get("batch_size"),
    shuffle=True,  # reshuffle the training samples every epoch
    num_workers=config.get("num_workers"),
)
val_loader = DataLoader(
    val_dataset,
    batch_size=config.get("batch_size"),
    shuffle=False,
    num_workers=config.get("num_workers"),
)
trainer.fit(model_module, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Test the model and visualize results
# LlavaModelPLModule defines no test_dataloader hook, so the test loader must be
# passed explicitly.
# NOTE(review): no collate_fn is visible in this notebook. This assumes each
# dataset item already collates into the (input_ids, attention_mask, pixel_values,
# image_sizes, labels) batch that test_step unpacks, with input_ids holding only
# the prompt that generate() should continue — TODO confirm.
test_loader = DataLoader(
    test_dataset,
    batch_size=config.get("batch_size"),
    shuffle=False,  # keep a stable order so buffered images line up with predictions
    num_workers=config.get("num_workers"),
)
trainer.test(model_module, dataloaders=test_loader)