In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import os
import time
from PIL import Image
import pickle

In [None]:
# Swin Transformer block
class SwinTransformerBlock(nn.Module):
    """Window-based self-attention block for image feature maps.

    The input map is split into non-overlapping ``window_size`` x
    ``window_size`` windows. The pixels of one window form a token sequence
    and attend to each other through ``nn.MultiheadAttention``; the attended
    tokens then pass through two Linear -> LayerNorm -> ReLU stages and are
    folded back into an image-shaped tensor.

    This is a simplified block: there are no shifted windows, no relative
    position bias and no residual connections.

    Args:
        in_channels: Channels of the input map. Also used as the attention
            embedding size, so it must be divisible by ``num_heads``.
        out_channels: Channels of the returned map.
        num_heads: Number of attention heads.
        window_size: Side length, in pixels, of each square attention window.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """

    def __init__(self, in_channels, out_channels, num_heads=1, window_size=7):
        super(SwinTransformerBlock, self).__init__()
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.window_size = window_size
        # batch_first=True: dim 0 (one entry per window) is the batch axis and
        # dim 1 (pixels of that window) is the sequence axis, so attention
        # never mixes different windows or different samples.
        self.attention = nn.MultiheadAttention(
            embed_dim=in_channels, num_heads=num_heads, batch_first=True
        )
        self.linear1 = nn.Linear(in_channels, out_channels)
        self.linear2 = nn.Linear(out_channels, out_channels)
        self.norm1 = nn.LayerNorm(out_channels)
        self.norm2 = nn.LayerNorm(out_channels)

    def forward(self, x):
        """Apply windowed attention to ``x``.

        Args:
            x: Tensor of shape (batch, in_channels, height, width).

        Returns:
            Tensor of shape (batch, out_channels, height, width).
        """
        batch, channels, height, width = x.shape
        ws = self.window_size

        # Zero-pad bottom/right so both spatial sizes are multiples of the
        # window size. Padded pixels take part in attention inside their
        # window and are cropped from the result.
        pad_h = (-height) % ws
        pad_w = (-width) % ws
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h))
        padded_h, padded_w = height + pad_h, width + pad_w
        rows, cols = padded_h // ws, padded_w // ws

        # (B, C, rows, ws, cols, ws) -> (B, rows, cols, ws, ws, C)
        #                            -> (B * rows * cols, ws * ws, C)
        # The permute is required: a plain view would reinterpret the NCHW
        # memory and mix channel and spatial values.
        windows = (
            x.reshape(batch, channels, rows, ws, cols, ws)
            .permute(0, 2, 4, 3, 5, 1)
            .reshape(batch * rows * cols, ws * ws, channels)
        )

        # The attention weights are not used, so skip materialising them.
        attn_output, _ = self.attention(windows, windows, windows, need_weights=False)

        tokens = F.relu(self.norm1(self.linear1(attn_output)))
        tokens = F.relu(self.norm2(self.linear2(tokens)))

        # Inverse of the window partition above.
        out_channels = tokens.size(-1)
        out = (
            tokens.reshape(batch, rows, cols, ws, ws, out_channels)
            .permute(0, 5, 1, 3, 2, 4)
            .reshape(batch, out_channels, padded_h, padded_w)
        )
        if pad_h or pad_w:
            out = out[:, :, :height, :width]
        return out

# ConvNeXt block
class ConvNeXtBlock(nn.Module):
    """3x3 convolution followed by per-pixel LayerNorm over channels and ReLU.

    Args:
        in_channels: Channels of the input map.
        out_channels: Channels produced by the convolution and normalised.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # stride 1 with padding 1 keeps the spatial size unchanged.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm = nn.LayerNorm(out_channels)
        # NOTE(review): this layer is registered (and therefore appears in the
        # state_dict) but forward() never applies it.
        self.linear = nn.Linear(out_channels, out_channels)

    def forward(self, x):
        """Map (batch, in_channels, H, W) to (batch, out_channels, H, W)."""
        features = self.conv(x)
        # LayerNorm normalises the last dimension, so put channels last for
        # the norm and restore channels-first afterwards.
        channels_last = features.permute(0, 2, 3, 1)
        normalised = self.norm(channels_last)
        return F.relu(normalised.permute(0, 3, 1, 2))

# Dataset class
class GarbageDataset(Dataset):
    """Image-classification dataset laid out as ``root_dir/<class>/<image>``.

    Every non-hidden sub-directory of ``root_dir`` is one class. Classes and
    the files inside them are sorted by name, so the class-to-index mapping
    and the sample order do not depend on the order in which the filesystem
    lists entries.

    Args:
        root_dir: Directory containing one sub-directory per class.
        transform: Optional callable applied to each loaded RGB PIL image.
        extensions: File suffixes (matched case-insensitively) that count as
            images.

    Attributes:
        classes: Sorted list of class names.
        class_to_idx: Mapping from class name to integer label.
        image_paths: Paths of all images found.
        labels: Integer label for each entry of ``image_paths``.
    """

    def __init__(self, root_dir, transform=None, extensions=(".jpg", ".jpeg", ".png")):
        self.root_dir = root_dir
        self.transform = transform
        suffixes = tuple(ext.lower() for ext in extensions)

        # Only real, non-hidden directories are classes; stray files such as
        # .DS_Store or folders such as .ipynb_checkpoints are skipped.
        self.classes = sorted(
            entry
            for entry in os.listdir(root_dir)
            if not entry.startswith(".") and os.path.isdir(os.path.join(root_dir, entry))
        )
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.image_paths = []
        self.labels = []

        for cls, idx in self.class_to_idx.items():
            class_dir = os.path.join(root_dir, cls)
            for name in sorted(os.listdir(class_dir)):
                if name.lower().endswith(suffixes):
                    self.image_paths.append(os.path.join(class_dir, name))
                    self.labels.append(idx)

    def __len__(self):
        """Return the number of images found."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to RGB and passed through ``transform`` when
        one was given.
        """
        img_path = self.image_paths[idx]
        # convert() returns a new in-memory image, so the file handle can be
        # closed as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# Model
class FusionModel(nn.Module):
    """Two-branch classifier fusing an attention branch and a conv branch.

    The input image goes through ``SwinTransformerBlock`` and
    ``ConvNeXtBlock`` in parallel (64 channels each). The two maps are
    concatenated along the channel axis, reduced to a single-channel map by a
    1x1 convolution, flattened, and classified by one linear layer.

    Args:
        num_classes: Number of output logits.
        image_size: Side length of the square input images. The classifier's
            input width is ``image_size * image_size``, so inputs of any other
            spatial size will fail in the classifier.
    """

    def __init__(self, num_classes=12, image_size=224):
        super(FusionModel, self).__init__()
        self.num_classes = num_classes
        self.image_size = image_size
        self.swin_transformer_block1 = SwinTransformerBlock(in_channels=3, out_channels=64)
        self.convnext_block1 = ConvNeXtBlock(in_channels=3, out_channels=64)
        # 128 = 64 channels from each branch after concatenation.
        self.spatial_attention_mechanism = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1)
        # One feature per pixel of the single-channel map.
        self.classifier = nn.Linear(image_size * image_size, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for images ``x``.

        Args:
            x: Tensor of shape (batch, 3, image_size, image_size).
        """
        swin_output = self.swin_transformer_block1(x)
        convnext_output = self.convnext_block1(x)
        combined_output = torch.cat((swin_output, convnext_output), dim=1)
        attention_output = self.spatial_attention_mechanism(combined_output)
        # (batch, 1, H, W) -> (batch, H * W)
        flat = attention_output.flatten(1)
        return self.classifier(flat)

In [None]:
# Reproducibility: seed the global RNG (weight initialisation, DataLoader
# shuffling) and use a dedicated seeded generator for the train/test split so
# the same images land in the same split on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Data transforms: resize to the 224x224 input the model expects, then
# normalise with the mean/std values given below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load datasets (80% train / 20% test)
dataset = datasets.ImageFolder(root='garbage_classification', transform=transform)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SEED),
)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Model Initialization
model = FusionModel()

# Prefer CUDA, then Apple MPS, then CPU. The getattr guard keeps this working
# on PyTorch builds that do not ship the MPS backend module.
_mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    device = torch.device('cuda')
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
model.to(device)

# Loss function and optimizer; the learning rate is multiplied by 0.1 every
# 7 scheduler steps.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

In [None]:
# Accuracy function
def accuracy(outputs, labels):
    """Return the fraction of rows whose highest-scoring class equals the label.

    Args:
        outputs: Score tensor of shape (batch, num_classes).
        labels: Integer class labels of shape (batch,).

    Returns:
        A Python float in [0, 1].
    """
    predicted = outputs.argmax(dim=1)
    num_correct = (predicted == labels).sum().item()
    return num_correct / labels.size(0)

# Model Training
# Single source of truth for the epoch count (used by the loop and the logs).
NUM_EPOCHS = 5

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    running_corrects = 0.0
    total_samples = 0
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()

    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        batch_size = images.size(0)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Read the scalar loss and the batch accuracy once and reuse them for
        # both the running totals and the log line.
        batch_loss = loss.item()
        batch_acc = accuracy(outputs, labels)

        # Weight by batch size so a smaller final batch does not skew the
        # epoch averages.
        running_loss += batch_loss * batch_size
        running_corrects += batch_acc * batch_size
        total_samples += batch_size
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Batch [{i+1}/{len(train_loader)}], Loss: {batch_loss:.4f}, Accuracy: {batch_acc * 100:.2f}%")

    epoch_loss = running_loss / total_samples
    epoch_acc = running_corrects / total_samples

    epoch_duration = time.perf_counter() - start_time

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] completed in {epoch_duration:.2f} seconds, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc * 100:.2f}%")

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

In [None]:
# Model Saving
model_path = "fusion_model_offline.pkl"

# Pickle the model from CPU: tensors pickled while on an accelerator are
# restored onto that same device type, so a file written from 'mps' or 'cuda'
# could not be loaded on a machine without that backend. The model is moved
# back to the training device afterwards, even if writing fails.
#
# NOTE(review): pickling the whole module stores classes by reference, so
# loading requires the same class definitions to be importable, and
# pickle.load must only be used on trusted files. Saving model.state_dict()
# with torch.save is the more portable format if consumers can be migrated.
model.to("cpu")
try:
    with open(model_path, 'wb') as f:
        pickle.dump(model, f)
finally:
    model.to(device)

print(f"Model saved to {model_path}")