In [1]:
!pip install -q timm


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import timm


In [4]:
train_tfms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

val_tfms = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])


In [9]:
train_ds = datasets.ImageFolder(r"C:\Users\keval\OneDrive\Desktop\final", transform=train_tfms)
val_ds   = datasets.ImageFolder(r"C:\Users\keval\OneDrive\Desktop\final", transform=val_tfms)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=2)

print("Classes:", train_ds.classes)
print("Train:", len(train_ds), "Val:", len(val_ds))


Classes: ['test', 'train', 'val']
Train: 7822 Val: 7822


In [11]:
class RGBOnly(nn.Module):
    def __init__(self):
        super().__init__()
        self.rgb = timm.create_model("efficientnet_b0", pretrained=True, num_classes=1)

    def forward(self, x):
        return self.rgb(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = RGBOnly().to(device)


In [13]:
for param in model.rgb.parameters():
    param.requires_grad = False

classifier = model.rgb.get_classifier()
for param in classifier.parameters():
    param.requires_grad = True


In [15]:
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)


In [17]:
def train_one_epoch(model, loader):
    model.train()
    total_loss = 0
    for x,y in loader:
        x,y = x.to(device), y.float().to(device)
        optimizer.zero_grad()
        out = model(x).squeeze()
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def validate(model, loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            preds = torch.sigmoid(model(x)).squeeze() > 0.5
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total


In [19]:
for epoch in range(5):
    loss = train_one_epoch(model, train_loader)
    acc = validate(model, val_loader)
    print(f"Epoch {epoch+1}: Loss={loss:.4f}, Val Acc={acc:.4f}")


Epoch 1: Loss=1.6383, Val Acc=0.6163
Epoch 2: Loss=0.2472, Val Acc=0.6440
Epoch 3: Loss=0.1617, Val Acc=0.6519
Epoch 4: Loss=0.0583, Val Acc=0.6565
Epoch 5: Loss=0.0446, Val Acc=0.6634
