In [20]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import timm


In [22]:
# Transformations for train and validation datasets
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

valid_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load train and validation datasets
train_dataset = datasets.ImageFolder(root="data_waste/train", transform=train_transform)
valid_dataset = datasets.ImageFolder(root="data_waste/valid", transform=valid_transform)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=False)

# Number of classes
num_classes = len(train_dataset.classes)
print(f"Number of classes: {num_classes}")


Number of classes: 25


In [23]:
class Q2L(nn.Module):
    def __init__(self, num_labels, backbone="vit_base_patch16_224", pretrained=True):
        super(Q2L, self).__init__()
        self.backbone = timm.create_model(backbone, pretrained=pretrained, num_classes=0)
        self.query_embedding = nn.Embedding(num_labels, self.backbone.num_features)
        self.classifier = nn.Linear(self.backbone.num_features, num_labels)

    def forward(self, x):
        # Extract features from the backbone
        features = self.backbone(x)
        
        # Add query embeddings
        query_embeddings = self.query_embedding.weight  # Shape: [num_labels, embedding_dim]
        features = features.unsqueeze(1) + query_embeddings  # Shape: [batch_size, num_labels, embedding_dim]
        
        # Classify
        logits = self.classifier(features).squeeze(1)
        return logits


In [26]:
# Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Q2L(num_labels=num_classes).to(device)

# Loss and optimizer
criterion = nn.BCEWithLogitsLoss()  # For multi-label classification
optimizer = optim.AdamW(model.parameters(), lr=1e-4)


In [27]:
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    train_loss = 0

    for images, labels in train_loader:
        # Convert labels to one-hot encoding for multi-label classification
        labels_one_hot = torch.zeros((labels.size(0), num_classes), device=device)
        labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)

        images, labels_one_hot = images.to(device), labels_one_hot.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels_one_hot)
        train_loss += loss.item()

        # Backward pass
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {train_loss / len(train_loader):.4f}")


ValueError: Target size (torch.Size([16, 25])) must be the same as input size (torch.Size([16, 25, 25]))