In [1]:
import os

import matplotlib.pyplot as plt
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder

In [2]:
# Define the ChannelShuffle module
class ChannelShuffle(nn.Module):
    """Interleave the channels of a 4-D tensor across ``groups`` groups.

    The channel axis is viewed as ``(groups, channels_per_group)``, the two
    sub-axes are swapped, and the result is flattened back into one channel
    axis. The output has the same shape as the input.

    Args:
        groups: number of groups the channel axis is split into.
    """

    def __init__(self, groups):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        """Return ``x`` with its channels interleaved across the groups."""
        n, c, h, w = x.shape
        per_group = c // self.groups

        # Split the channel axis into (groups, per_group).
        grouped = x.view(n, self.groups, per_group, h, w)

        # Swapping the two sub-axes leaves a non-contiguous tensor, so it is
        # copied before the final view merges them back together.
        interleaved = grouped.transpose(1, 2).contiguous()
        return interleaved.view(n, -1, h, w)

In [3]:
# Define the Coordinate Attention Module (CoordAtt)
class CoordAtt(nn.Module):
    """Gate a feature map with separate per-row and per-column attention weights.

    The input is average-pooled along the width and along the height, both
    descriptors pass through one shared 1x1 conv + batch norm + ReLU
    bottleneck, and two 1x1 convs followed by a sigmoid turn them into
    multiplicative gates that are applied to the input.

    Args:
        in_channels: channel count of the input (and of the output).
        reduction: divisor used to size the bottleneck; the bottleneck is
            never narrower than 8 channels.
    """

    def __init__(self, in_channels, reduction=32):
        super().__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))  # keep height, collapse width
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))  # collapse height, keep width

        bottleneck = max(8, in_channels // reduction)
        self.conv1 = nn.Conv2d(in_channels, bottleneck, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(bottleneck)
        self.act = nn.ReLU(inplace=True)

        self.conv_h = nn.Conv2d(bottleneck, in_channels, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(bottleneck, in_channels, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled by the row gate and the column gate (same shape as ``x``)."""
        _, _, height, width = x.size()

        # (N, C, H, 1) row descriptor; the column descriptor is turned to
        # (N, C, W, 1) so the two can be stacked along dim 2.
        row_desc = self.pool_h(x)
        col_desc = self.pool_w(x).permute(0, 1, 3, 2)

        # One pass of the shared bottleneck over both descriptors.
        stacked = torch.cat([row_desc, col_desc], dim=2)
        stacked = self.act(self.bn1(self.conv1(stacked)))

        # Undo the stacking and turn the column part back to (N, mip, 1, W).
        row_feat, col_feat = torch.split(stacked, [height, width], dim=2)
        col_feat = col_feat.permute(0, 1, 3, 2)

        gate_h = self.sigmoid(self.conv_h(row_feat))
        gate_w = self.sigmoid(self.conv_w(col_feat))

        # The gates broadcast over width and height respectively.
        return x * gate_h * gate_w


In [4]:
# Define the RCSA module
class RCSA(nn.Module):
    """Channel-attention block: shuffle, squeeze/excite, coordinate attention, rescale.

    The input is channel-shuffled, globally average-pooled to 1x1, passed
    through a two-layer 1x1-conv bottleneck, refined by ``CoordAtt``, squashed
    with a sigmoid and finally used as a per-channel multiplier on the
    original input.

    Args:
        in_channels: channel count of the input (and of the output).
        reduction: divisor used to size the bottleneck. The bottleneck is
            clamped to at least one channel so that ``in_channels < reduction``
            does not create a zero-width convolution.
        groups: group count handed to ``ChannelShuffle``.
    """

    def __init__(self, in_channels, reduction=16, groups=4):
        super(RCSA, self).__init__()
        self.groups = groups
        self.channel_shuffle = ChannelShuffle(groups)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Plain ``in_channels // reduction`` is 0 for narrow inputs; keep at
        # least one bottleneck channel. Identical to the plain division
        # whenever in_channels >= reduction.
        squeezed_channels = max(1, in_channels // reduction)
        self.conv1 = nn.Conv2d(in_channels, squeezed_channels, 1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(squeezed_channels, in_channels, 1, bias=False)
        self.coord_att = CoordAtt(in_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` multiplied by the computed (N, C, 1, 1) attention weights."""
        residual = x

        # Channel shuffle
        x = self.channel_shuffle(x)

        # Squeeze and excitation: global pooling leaves a (N, C, 1, 1) tensor.
        x = self.avg_pool(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)

        # Coordinate attention.
        # NOTE(review): at this point the spatial size is already 1x1, so the
        # row/column pooling inside CoordAtt has nothing positional to work
        # with. Confirm this ordering is intended.
        x = self.coord_att(x)
        x = self.sigmoid(x)

        # Scale: the (N, C, 1, 1) weights broadcast over the spatial dims.
        x = residual * x
        return x


In [5]:
# Define a custom CNN with RCSA and Coordinate Attention
class CustomCNN(nn.Module):
    """Three conv -> ReLU -> RCSA -> max-pool stages followed by a two-layer classifier.

    Each stage halves the spatial resolution, so the feature map entering the
    classifier is ``256 x (H // 8) x (W // 8)``.

    Args:
        num_classes: width of the output layer.
        input_size: spatial size of the images the model will be fed, either
            one int for square images or an ``(height, width)`` pair. It is
            used only to size the first fully connected layer. The default of
            224 gives the ``256 * 28 * 28`` input width.

    Raises:
        ValueError: if ``input_size`` is too small to survive three 2x poolings.
    """

    def __init__(self, num_classes, input_size=224):
        super(CustomCNN, self).__init__()
        if isinstance(input_size, int):
            in_height = in_width = input_size
        else:
            in_height, in_width = input_size

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.rcsa1 = RCSA(64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.rcsa2 = RCSA(128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.rcsa3 = RCSA(256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # The 3x3 convs use padding=1 and keep the size; each of the three
        # 2x2/stride-2 poolings floors the size in half, i.e. // 8 overall.
        flat_features = 256 * (in_height // 8) * (in_width // 8)
        if flat_features <= 0:
            raise ValueError(
                f"input_size {input_size!r} is too small: each side must be at least 8 pixels"
            )
        self.fc1 = nn.Linear(flat_features, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``(N, num_classes)`` logits for a batch of 3-channel images.

        The images must have the spatial size given as ``input_size``;
        otherwise the flattened features will not match ``fc1``.
        """
        x = self.relu(self.conv1(x))
        x = self.rcsa1(x)
        x = self.pool1(x)

        x = self.relu(self.conv2(x))
        x = self.rcsa2(x)
        x = self.pool2(x)

        x = self.relu(self.conv3(x))
        x = self.rcsa3(x)
        x = self.pool3(x)

        # Flatten everything but the batch dim; unlike view() this also
        # accepts tensors that are not contiguous.
        x = x.flatten(1)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [6]:
class CombinedModel(nn.Module):
    """Fuse a timm CrossViT backbone with ``CustomCNN`` and classify the concatenation.

    Both sub-models see the same input batch. Their outputs are concatenated
    along the feature axis and passed through a two-layer head.

    Args:
        crossvit_name: timm model name of the CrossViT variant to create.
        num_classes: number of output classes; also the width of the
            ``CustomCNN`` output that is fused with the CrossViT features.
        pretrained: forwarded to ``timm.create_model``; ``True`` loads the
            pretrained weights (which may require a download).
    """

    def __init__(self, crossvit_name='crossvit_15_240', num_classes=10, pretrained=True):
        super(CombinedModel, self).__init__()
        self.crossvit_model = timm.create_model(crossvit_name, pretrained=pretrained)

        # CrossViT keeps one classifier per branch. Record the feature width
        # of every branch before the classifiers are removed.
        branch_dims = [head.in_features for head in self.crossvit_model.head]

        # Remove the classifier heads from CrossViT.
        # NOTE(review): this assumes the installed timm returns the branch
        # features concatenated along dim 1 when the heads are Identity, so
        # the backbone output width is the SUM of the branch widths. Confirm
        # against the installed timm version.
        self.crossvit_model.head = nn.ModuleList([nn.Identity() for _ in branch_dims])
        crossvit_output_dim = sum(branch_dims)

        self.custom_cnn = CustomCNN(num_classes)
        # CustomCNN.forward ends in Linear(128, num_classes), so what gets
        # concatenated below is (N, num_classes), not the flattened conv map.
        cnn_output_dim = num_classes

        self.fc1 = nn.Linear(crossvit_output_dim + cnn_output_dim, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return ``(N, num_classes)`` logits for the image batch ``x``."""
        crossvit_out = self.crossvit_model(x)
        cnn_out = self.custom_cnn(x)
        combined_out = torch.cat((crossvit_out, cnn_out), dim=1)
        combined_out = F.relu(self.fc1(combined_out))
        combined_out = self.fc2(combined_out)
        return combined_out


In [7]:

# Function to check if image is valid
def is_valid_image(image_path):
    """Return True if the file at ``image_path`` can be opened and converted to RGB.

    Args:
        image_path: path of the image file to check.

    Returns:
        ``True`` when opening and RGB conversion succeed, ``False`` when either
        step raises ``FileNotFoundError``, ``OSError`` or ``ValueError``.
    """
    try:
        # The context manager closes the underlying file once the check is
        # done, so validating many files does not accumulate open handles.
        with Image.open(image_path) as img:
            # convert() forces the pixel data to be decoded.
            img.convert('RGB')
        return True
    except (FileNotFoundError, OSError, ValueError):
        return False

# Custom dataset class with error handling
class CustomImageFolder(ImageFolder):
    """``ImageFolder`` that substitutes a placeholder sample for unreadable images.

    When an image cannot be loaded, ``__getitem__`` returns an all-zero tensor
    together with ``INVALID_TARGET`` instead of raising, so one corrupt file
    does not abort a whole epoch.
    """

    # Label attached to placeholder samples. -100 is the default
    # ``ignore_index`` of nn.CrossEntropyLoss, so these samples are left out
    # of the loss. A label equal to ``len(self.classes)`` would instead be
    # rejected by the loss as out of range.
    INVALID_TARGET = -100

    # Shape of the placeholder tensor.
    # NOTE(review): this must match what ``transform`` produces, otherwise
    # batches containing a placeholder cannot be collated. Confirm against
    # the transform in use.
    PLACEHOLDER_SHAPE = (3, 224, 224)

    def __init__(self, root, transform=None):
        super().__init__(root, transform=transform)

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``, or a placeholder pair on a load failure."""
        # Get image path and label from original ImageFolder
        path, target = self.samples[index]

        # Decode the file once; a failed load is the validity check.
        try:
            image = self.loader(path)
        except (OSError, ValueError):  # FileNotFoundError is an OSError
            return torch.zeros(self.PLACEHOLDER_SHAPE), self.INVALID_TARGET

        if self.transform is not None:
            image = self.transform(image)

        return image, target

# Dataset path: set the PEAR_DATASET_PATH environment variable to point the
# notebook at another copy of the data; otherwise the original location is used.
dataset_path = os.environ.get(
    'PEAR_DATASET_PATH',
    'C:/Users/aryan/Desktop/WORK/Summer Research Internship/Pear/leaves',
)


In [8]:
# Transformations: resize to 224x224, convert to a tensor, then map each
# channel from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load the custom dataset
dataset = CustomImageFolder(dataset_path, transform=transform)

# Split the dataset into training and validation sets (80 / 20).
# A seeded generator makes the split identical on every Restart & Run All,
# so validation numbers from different runs are comparable.
SEED = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

# Define the DataLoader
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


# Instantiate the model with one output per class folder found in the dataset
combined_model = CombinedModel(num_classes=len(dataset.classes))

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(combined_model.parameters(), lr=0.001)

In [None]:
# Training loop
# Per epoch: one optimisation pass over train_loader, then one evaluation pass
# over val_loader. The four lists below get one entry per epoch.
num_epochs = 10
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

# Batches are moved to whatever device the model's parameters live on.
device = next(combined_model.parameters()).device

for epoch in range(num_epochs):
    combined_model.train()
    running_loss = 0.0     # window for the every-100-steps progress print only
    epoch_loss_sum = 0.0   # sample-weighted loss over the whole epoch
    correct_train = 0
    total_train = 0

    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = combined_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_count = labels.size(0)
        running_loss += batch_loss
        # The criterion returns a batch mean; weighting by the batch size
        # keeps a smaller final batch from being over-represented.
        epoch_loss_sum += batch_loss * batch_count

        predicted = outputs.detach().argmax(dim=1)
        total_train += batch_count
        correct_train += (predicted == labels).sum().item()

        if i % 100 == 99:    # Print every 100 mini-batches
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {running_loss / 100:.4f}')
            running_loss = 0.0

    # The epoch average uses its own accumulator: running_loss is reset by the
    # progress print, so it only holds the tail of the epoch.
    # max(..., 1) avoids a ZeroDivisionError when the loader is empty.
    train_losses.append(epoch_loss_sum / max(total_train, 1))
    train_accuracies.append(100 * correct_train / max(total_train, 1))

    combined_model.eval()
    val_loss_sum = 0.0
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = combined_model(inputs)
            loss = criterion(outputs, labels)

            batch_count = labels.size(0)
            val_loss_sum += loss.item() * batch_count

            predicted = outputs.argmax(dim=1)
            total_val += batch_count
            correct_val += (predicted == labels).sum().item()

    val_losses.append(val_loss_sum / max(total_val, 1))
    val_accuracies.append(100 * correct_val / max(total_val, 1))

    print(f'Epoch [{epoch + 1}/{num_epochs}], Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}, Train Acc: {train_accuracies[-1]:.2f}%, Val Acc: {val_accuracies[-1]:.2f}%')

print('Finished Training')

In [None]:
# Plotting the results
# Two side-by-side panels: loss per epoch on the left, accuracy on the right.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(train_losses, label='Training Loss')
loss_ax.plot(val_losses, label='Validation Loss')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training and Validation Loss')
loss_ax.legend()

acc_ax.plot(train_accuracies, label='Training Accuracy')
acc_ax.plot(val_accuracies, label='Validation Accuracy')
acc_ax.set_xlabel('Epochs')
acc_ax.set_ylabel('Accuracy (%)')
acc_ax.set_title('Training and Validation Accuracy')
acc_ax.legend()

plt.show()