# Imports

In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from transformers import ViTModel, ViTConfig
from PIL import Image
import cv2
import numpy as np
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Device

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Constants

In [3]:
BATCH_SIZE = 32
EPOCHS = 25
LEARNING_RATE = 0.001
NUM_CLASSES = 2
DATA_DIR = "/kaggle/input/dataset/dataset"
MODEL_SAVE_PATH = "multi_attribute_fish_model_novel.pth"
VIT_CONFIG_PATH = "/kaggle/input/vitfiles/config.json"
VIT_MODEL_PATH = "/kaggle/input/vitfiles/pytorch_model.bin"
RESNET18_WEIGHTS_PATH = "/kaggle/input/pre-trained-resnet/resnet18-f37072fd.pth"

# Early stopping
best_val_loss = float('inf')
patience = 5
counter = 0

# Dataset Class

In [4]:
class FishDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.classes = ["fresh", "non_fresh"]
        self.image_paths = []
        self.labels = []

        # Load image paths and labels
        for label_idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(data_dir, class_name)
            for attribute in ["eyes", "gills"]:  # Only eyes and gills
                attribute_dir = os.path.join(class_dir, attribute)
                for img_name in os.listdir(attribute_dir):
                    self.image_paths.append((os.path.join(attribute_dir, img_name), attribute, label_idx))
                    self.labels.append(label_idx)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path, attribute, label = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        
        # Apply domain-specific preprocessing
        image = self.preprocess_image(image, attribute)
        
        # Apply transformations
        if self.transform:
            image = self.transform(image)
        
        return image, label
        
    def preprocess_image(self, image, attribute):
        image = np.array(image)
        
        if attribute == "eyes":
            # Reddish color detection for non-fresh eyes
            hsv_image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
            lower_red = np.array([0, 50, 50])  # Lower range for red
            upper_red = np.array([10, 255, 255])  # Upper range for red
            red_mask = cv2.inRange(hsv_image, lower_red, upper_red)
            
            # Glitter/reflectivity detection for fresh eyes
            gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray_image, 100, 200)  # Detect edges
            
            # Combine masks
            combined_mask = cv2.bitwise_or(red_mask, edges)
            image = cv2.bitwise_and(image, image, mask=combined_mask)
        
        elif attribute == "gills":
            # Enhance color contrast for gills
            lab_image = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
            l_channel, a_channel, b_channel = cv2.split(lab_image)
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            l_channel = clahe.apply(l_channel)
            lab_image = cv2.merge((l_channel, a_channel, b_channel))
            image = cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
        
        return Image.fromarray(image)

# Preprocessing Transformations with Augmentation

In [5]:
# Preprocessing Transformations with Augmentation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),  # Randomly flip images horizontally
    transforms.RandomRotation(10),     # Randomly rotate images by ±10 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # Randomly adjust brightness, contrast, and saturation
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization
])

# Loading Datasets

In [6]:
train_dataset = FishDataset(data_dir=os.path.join(DATA_DIR, "train"), transform=transform)
valid_dataset = FishDataset(data_dir=os.path.join(DATA_DIR, "valid"), transform=transform)
test_dataset = FishDataset(data_dir=os.path.join(DATA_DIR, "test"), transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

# Model Definition

In [7]:
class MultiAttributeFishModel(nn.Module):
    def __init__(self, num_classes=2):
        super(MultiAttributeFishModel, self).__init__()
        
        # Pre-trained CNNs for feature extraction
        self.eye_cnn = models.resnet18(pretrained=False)
        self.eye_cnn.load_state_dict(torch.load(RESNET18_WEIGHTS_PATH))
        
        self.gill_cnn = models.resnet18(pretrained=False)
        self.gill_cnn.load_state_dict(torch.load(RESNET18_WEIGHTS_PATH))
        
        # Modify CNNs for single-channel output
        self.eye_cnn.fc = nn.Linear(self.eye_cnn.fc.in_features, 128)
        self.gill_cnn.fc = nn.Linear(self.gill_cnn.fc.in_features, 128)
        
        # Vision Transformer for global context
        vit_config = ViTConfig.from_pretrained(VIT_CONFIG_PATH)
        self.vit = ViTModel(vit_config)

        # Load the state dictionary and remove the 'vit.' prefix
        state_dict = torch.load(VIT_MODEL_PATH)
        state_dict = {k.replace("vit.", ""): v for k, v in state_dict.items()}

        # Remove classifier-related keys
        state_dict = {k: v for k, v in state_dict.items() if not k.startswith("classifier.")}

        # Load the modified state dictionary into the ViT model
        self.vit.load_state_dict(state_dict, strict=False)  # Set strict=False to ignore missing keys

        self.vit_fc = nn.Linear(self.vit.config.hidden_size, 128)
        
        # Attention Mechanism for Eyes and Gills
        self.eye_attention = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 1),
            nn.Softmax(dim=1)
        )
        self.gill_attention = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 1),
            nn.Softmax(dim=1)
        )
        
        # Weighted Fusion Layer
        self.fusion_fc = nn.Sequential(
            nn.Linear(128 * 3, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, eye_img, gill_img):
        # Extract features using CNNs
        eye_features = self.eye_cnn(eye_img)  # [batch_size, 128]
        gill_features = self.gill_cnn(gill_img)  # [batch_size, 128]
        
        # Apply Attention Mechanisms
        eye_weights = self.eye_attention(eye_features)  # [batch_size, 1]
        gill_weights = self.gill_attention(gill_features)  # [batch_size, 1]
        
        eye_features = eye_features * eye_weights  # Weighted features
        gill_features = gill_features * gill_weights  # Weighted features
        
        # Extract global context using ViT
        vit_outputs = self.vit(eye_img)  # Use eye_img as input to ViT
        vit_features = self.vit_fc(vit_outputs.last_hidden_state.mean(dim=1))  # [batch_size, 128]
        
        # Concatenate features
        combined_features = torch.cat([eye_features, gill_features, vit_features], dim=1)  # [batch_size, 128 * 3]
        
        # Final classification
        output = self.fusion_fc(combined_features)  # [batch_size, num_classes]
        return output

# Initialize Model, Loss, Optimizer and Scheduler

In [8]:
# Model. Loss, Optimizer
model = MultiAttributeFishModel(num_classes=NUM_CLASSES).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)  # Add weight decay

# Initialize the scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)


  self.eye_cnn.load_state_dict(torch.load(RESNET18_WEIGHTS_PATH))
  self.gill_cnn.load_state_dict(torch.load(RESNET18_WEIGHTS_PATH))
  state_dict = torch.load(VIT_MODEL_PATH)


# Training and Validation Loops

In [9]:
# Training 
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        
        # Forward pass
        outputs = model(images, images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{EPOCHS}], Training Loss: {running_loss/len(train_loader):.4f}")

 # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images, images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    val_accuracy = 100 * correct / total
    print(f"Validation Loss: {val_loss/len(valid_loader):.4f}, Validation Accuracy: {val_accuracy:.2f}%")

  # Step the scheduler
    scheduler.step(val_loss)
    
    # Early stopping logic
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        print(f"Model saved to {MODEL_SAVE_PATH}")
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered!")
            break



Epoch [1/25], Training Loss: 0.6362
Validation Loss: 0.5665, Validation Accuracy: 73.91%
Model saved to multi_attribute_fish_model_novel.pth
Epoch [2/25], Training Loss: 0.5739
Validation Loss: 0.5092, Validation Accuracy: 73.91%
Model saved to multi_attribute_fish_model_novel.pth
Epoch [3/25], Training Loss: 0.5616
Validation Loss: 0.4798, Validation Accuracy: 78.53%
Model saved to multi_attribute_fish_model_novel.pth
Epoch [4/25], Training Loss: 0.5574
Validation Loss: 0.7232, Validation Accuracy: 63.56%
Epoch [5/25], Training Loss: 0.5091
Validation Loss: 0.5234, Validation Accuracy: 79.30%
Epoch [6/25], Training Loss: 0.5275
Validation Loss: 0.4348, Validation Accuracy: 80.58%
Model saved to multi_attribute_fish_model_novel.pth
Epoch [7/25], Training Loss: 0.4949
Validation Loss: 0.4789, Validation Accuracy: 79.98%
Epoch [8/25], Training Loss: 0.5143
Validation Loss: 0.4224, Validation Accuracy: 81.86%
Model saved to multi_attribute_fish_model_novel.pth
Epoch [9/25], Training Loss:

# Testing

In [11]:

model.eval()
test_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = model(images, images)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Loss: {test_loss/len(test_loader):.4f}, Test Accuracy: {100 * correct / total:.2f}%")

Test Loss: 0.2680, Test Accuracy: 90.00%


# Complete Code

In [5]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from transformers import ViTModel, ViTConfig
from PIL import Image
import cv2
import numpy as np
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
EPOCHS = 20
LEARNING_RATE = 0.001
NUM_CLASSES = 2  # Fresh and Non-Fresh
DATA_DIR = "/kaggle/input/dataset/dataset"
MODEL_SAVE_PATH = "multi_attribute_fish_model_novel.pth"
# Path to the uploaded ViT files
VIT_CONFIG_PATH = "/kaggle/input/vitfiles/config.json"
VIT_MODEL_PATH = "/kaggle/input/vitfiles/pytorch_model.bin"
# Path to the uploaded weights file
RESNET18_WEIGHTS_PATH = "/kaggle/input/pre-trained-resnet/resnet18-f37072fd.pth"

# Dataset Class
class FishDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.classes = ["fresh", "non_fresh"]
        self.image_paths = []
        self.labels = []

        # Load image paths and labels
        for label_idx, class_name in enumerate(self.classes):
            class_dir = os.path.join(data_dir, class_name)
            for attribute in ["eyes", "gills"]:  # Only eyes and gills
                attribute_dir = os.path.join(class_dir, attribute)
                for img_name in os.listdir(attribute_dir):
                    self.image_paths.append((os.path.join(attribute_dir, img_name), attribute, label_idx))
                    self.labels.append(label_idx)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path, attribute, label = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        
        # Apply domain-specific preprocessing
        image = self.preprocess_image(image, attribute)
        
        # Apply transformations
        if self.transform:
            image = self.transform(image)
        
        return image, label

    def preprocess_image(self, image, attribute):
        image = np.array(image)
        
        if attribute == "eyes":
            # Reddish color detection for non-fresh eyes
            hsv_image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
            lower_red = np.array([0, 50, 50])  # Lower range for red
            upper_red = np.array([10, 255, 255])  # Upper range for red
            red_mask = cv2.inRange(hsv_image, lower_red, upper_red)
            
            # Glitter/reflectivity detection for fresh eyes
            gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray_image, 100, 200)  # Detect edges
            
            # Combine masks
            combined_mask = cv2.bitwise_or(red_mask, edges)
            image = cv2.bitwise_and(image, image, mask=combined_mask)
        
        elif attribute == "gills":
            # Enhance color contrast for gills
            lab_image = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
            l_channel, a_channel, b_channel = cv2.split(lab_image)
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            l_channel = clahe.apply(l_channel)
            lab_image = cv2.merge((l_channel, a_channel, b_channel))
            image = cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)
        
        return Image.fromarray(image)

# Preprocessing Transformations with Augmentation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),  # Randomly flip images horizontally
    transforms.RandomRotation(10),     # Randomly rotate images by ±10 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # Randomly adjust brightness, contrast, and saturation
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization
])

# Load Datasets
train_dataset = FishDataset(data_dir=os.path.join(DATA_DIR, "train"), transform=transform)
valid_dataset = FishDataset(data_dir=os.path.join(DATA_DIR, "valid"), transform=transform)
test_dataset = FishDataset(data_dir=os.path.join(DATA_DIR, "test"), transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)

# Model Definition
class MultiAttributeFishModel(nn.Module):
    def __init__(self, num_classes=2):
        super(MultiAttributeFishModel, self).__init__()
        
        # Pre-trained CNNs for feature extraction
        self.eye_cnn = models.resnet18(pretrained=False)
        self.eye_cnn.load_state_dict(torch.load(RESNET18_WEIGHTS_PATH))
        
        self.gill_cnn = models.resnet18(pretrained=False)
        self.gill_cnn.load_state_dict(torch.load(RESNET18_WEIGHTS_PATH))
        
        # Modify CNNs for single-channel output
        self.eye_cnn.fc = nn.Linear(self.eye_cnn.fc.in_features, 128)
        self.gill_cnn.fc = nn.Linear(self.gill_cnn.fc.in_features, 128)
        
        # Vision Transformer for global context
        vit_config = ViTConfig.from_pretrained(VIT_CONFIG_PATH)
        self.vit = ViTModel(vit_config)

        # Load the state dictionary and remove the 'vit.' prefix
        state_dict = torch.load(VIT_MODEL_PATH)
        state_dict = {k.replace("vit.", ""): v for k, v in state_dict.items()}

        # Remove classifier-related keys
        state_dict = {k: v for k, v in state_dict.items() if not k.startswith("classifier.")}

        # Load the modified state dictionary into the ViT model
        self.vit.load_state_dict(state_dict, strict=False)  # Set strict=False to ignore missing keys

        self.vit_fc = nn.Linear(self.vit.config.hidden_size, 128)
        
        # Attention Mechanism for Eyes and Gills
        self.eye_attention = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.5),  # Add dropout
            nn.Linear(128, 1),
            nn.Softmax(dim=1)
        )
        self.gill_attention = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.5),  # Add dropout
            nn.Linear(128, 1),
            nn.Softmax(dim=1)
        )
        
        # Weighted Fusion Layer
        self.fusion_fc = nn.Sequential(
            nn.Linear(128 * 3, 256),
            nn.ReLU(),
            nn.Dropout(0.5),  # Add dropout
            nn.Linear(256, num_classes)
        )

    def forward(self, eye_img, gill_img):
        # Extract features using CNNs
        eye_features = self.eye_cnn(eye_img)  # [batch_size, 128]
        gill_features = self.gill_cnn(gill_img)  # [batch_size, 128]
        
        # Apply Attention Mechanisms
        eye_weights = self.eye_attention(eye_features)  # [batch_size, 1]
        gill_weights = self.gill_attention(gill_features)  # [batch_size, 1]
        
        eye_features = eye_features * eye_weights  # Weighted features
        gill_features = gill_features * gill_weights  # Weighted features
        
        # Extract global context using ViT
        vit_outputs = self.vit(eye_img)  # Use eye_img as input to ViT
        vit_features = self.vit_fc(vit_outputs.last_hidden_state.mean(dim=1))  # [batch_size, 128]
        
        # Concatenate features
        combined_features = torch.cat([eye_features, gill_features, vit_features], dim=1)  # [batch_size, 128 * 3]
        
        # Final classification
        output = self.fusion_fc(combined_features)  # [batch_size, num_classes]
        return output

# Initialize Model, Loss, and Optimizer
model = MultiAttributeFishModel(num_classes=NUM_CLASSES).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)  # Add weight decay

# Initialize the scheduler
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

# Early stopping
best_val_loss = float('inf')
patience = 5
counter = 0

# Training and Validation Loop
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        
        # Forward pass
        outputs = model(images, images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{EPOCHS}], Training Loss: {running_loss/len(train_loader):.4f}")

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images, images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    val_accuracy = 100 * correct / total
    print(f"Validation Loss: {val_loss/len(valid_loader):.4f}, Validation Accuracy: {val_accuracy:.2f}%")
    
    # Step the scheduler
    scheduler.step(val_loss)
    
    # Early stopping logic
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        print(f"Model saved to {MODEL_SAVE_PATH}")
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered!")
            break

# Test the Model
model.eval()
test_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = model(images, images)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Loss: {test_loss/len(test_loader):.4f}, Test Accuracy: {100 * correct / total:.2f}%")

  self.eye_cnn.load_state_dict(torch.load(RESNET18_WEIGHTS_PATH))
  self.gill_cnn.load_state_dict(torch.load(RESNET18_WEIGHTS_PATH))
  state_dict = torch.load(VIT_MODEL_PATH)


Epoch [1/20], Training Loss: 0.6340
Validation Loss: 0.5045, Validation Accuracy: 76.05%
Model saved to multi_attribute_fish_model_novel.pth
Epoch [2/20], Training Loss: 0.5636
Validation Loss: 1.0206, Validation Accuracy: 56.63%
Epoch [3/20], Training Loss: 0.5668
Validation Loss: 1.1902, Validation Accuracy: 71.34%
Epoch [4/20], Training Loss: 0.5358
Validation Loss: 0.4391, Validation Accuracy: 79.64%
Model saved to multi_attribute_fish_model_novel.pth
Epoch [5/20], Training Loss: 0.4993
Validation Loss: 0.4342, Validation Accuracy: 79.64%
Model saved to multi_attribute_fish_model_novel.pth
Epoch [6/20], Training Loss: 0.4996
Validation Loss: 0.4102, Validation Accuracy: 79.98%
Model saved to multi_attribute_fish_model_novel.pth
Epoch [7/20], Training Loss: 0.4732
Validation Loss: 0.4225, Validation Accuracy: 80.07%
Epoch [8/20], Training Loss: 0.4864
Validation Loss: 0.4235, Validation Accuracy: 80.58%
Epoch [9/20], Training Loss: 0.4787
Validation Loss: 0.6019, Validation Accuracy