In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader, Dataset

#UCMerced
#Elastic WEight consolidation

In [16]:
class ResidualAdapter(nn.Module):
    def __init__(self, in_channels, reduction=32):
        super().__init__()

        bottleneck = in_channels // reduction   # reduced channels

        self.adapter = nn.Sequential(
            nn.Conv2d(in_channels, bottleneck, kernel_size=1, bias=False),
            nn.BatchNorm2d(bottleneck),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck, in_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(in_channels)
        )

    def forward(self, x):
        return self.adapter(x)


class AdapterBlock(nn.Module):
    def __init__(self, block, channels):
        super().__init__()
        self.block = block
        self.bn = nn.BatchNorm2d(channels)
        self.adapter = ResidualAdapter(channels)

    def forward(self, x):
        block_out = self.block(x)
        normed = self.bn(block_out)
        return block_out + self.adapter(normed)
    
class ResNetWithAdapters(nn.Module):
    def __init__(self, base, domain_list, domain_num_classes):
        super().__init__()

        # Shared stem
        self.stem = nn.Sequential(
            base.conv1,
            base.bn1,
            base.relu,
            base.maxpool
        )
        self.base_layers = nn.ModuleDict({
            'layer1': base.layer1,
            'layer2': base.layer2,
            'layer3': base.layer3,
            'layer4': base.layer4
        })
        
                # Shared layers (no adapters)
        self.layer1 = base.layer1
        self.layer2 = base.layer2

        # Domain-specific adapters only for last 2 layers
        self.adapters = nn.ModuleDict({
            domain: nn.ModuleDict({
                'layer3': self._wrap_with_adapters(base.layer3, 1024),
                'layer4': self._wrap_with_adapters(base.layer4, 2048)
            })
            for domain in domain_list
        })

        # Classifiers per domain
        self.classifiers = nn.ModuleDict({
            domain: nn.Linear(2048, domain_num_classes[domain])
            for domain in domain_list
        })

        self.avgpool = base.avgpool

    def _wrap_with_adapters(self, layer, channels):
        # Wrap each block with AdapterBlock
        return nn.Sequential(
            *[AdapterBlock(block, channels) for block in layer]
        )

    def forward(self, x, domain):
        x = self.stem(x)

        # Shared
        x = self.layer1(x)
        x = self.layer2(x)

        # Domain-specific adapters
        x = self.adapters[domain]['layer3'](x)
        x = self.adapters[domain]['layer4'](x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifiers[domain](x)


In [3]:
from itertools import islice
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import DataLoader, IterableDataset
from sklearn.model_selection import train_test_split

import torchvision.transforms as transforms

train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])





In [4]:
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset

class HFDatasetWrapper(Dataset):
    def __init__(self, dataset, transform=None, label_key='label', label_to_idx=None):
        self.dataset = dataset
        self.transform = transform
        self.label_key = label_key
        self.label_to_idx = label_to_idx

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = np.array(item['image'])
        image = Image.fromarray(image)
    
        if self.transform:
            image = self.transform(image)
    
        label = item[self.label_key]
    
        if self.label_to_idx is not None:
            if isinstance(label, list):  
                label_tensor = torch.zeros(len(self.label_to_idx), dtype=torch.float32)
                for lbl in label:
                    label_tensor[self.label_to_idx[lbl]] = 1.0
            else:  
                label_tensor = torch.tensor(self.label_to_idx[label], dtype=torch.long)
        else:
            label_tensor = torch.tensor(label, dtype=torch.long)  
    
        return image, label_tensor



In [5]:
import os
num_cpus = os.cpu_count()


In [6]:
dataset_1 = load_dataset("blanchon/EuroSAT_RGB", split="train")

split_dataset_1 = dataset_1.train_test_split(test_size=0.6, seed=42)


euroSAT_train = split_dataset_1['train'] 
euroSAT_test = split_dataset_1['test']    

euroSAT_train = HFDatasetWrapper(euroSAT_train, transform=train_transform)
euroSAT_test = HFDatasetWrapper(euroSAT_test, transform=test_transform)

euroSAT_train_loader = DataLoader(euroSAT_train, batch_size=64, shuffle=True, num_workers = num_cpus // 2,persistent_workers=True,pin_memory=True)
euroSAT_test_loader = DataLoader(euroSAT_test, batch_size=64, shuffle=False,num_workers = num_cpus // 2,persistent_workers=True,pin_memory=True)




In [7]:
dataset_2 = load_dataset("blanchon/PatternNet", split="train")
split_dataset_2 = dataset_2.train_test_split(test_size=0.6, seed=42)

patternNet_train = split_dataset_2['train']
patternNet_test = split_dataset_2['test']

patternNet_train = HFDatasetWrapper(patternNet_train, transform=train_transform)
patternNet_test = HFDatasetWrapper(patternNet_test, transform=test_transform)

patternNet_train_loader = DataLoader(patternNet_train, batch_size=64, shuffle=True, num_workers = num_cpus // 2,persistent_workers=True,pin_memory=True)
patternNet_test_loader = DataLoader(patternNet_test, batch_size=64, shuffle=False, num_workers = num_cpus // 2,persistent_workers=True,pin_memory=True)


In [8]:
# dataset_3 = load_dataset("blanchon/RESISC45", split="train")

# split_dataset_3 = dataset_3.train_test_split(test_size=0.4, seed=42)


# RESISC_train = split_dataset_3['train']8
# RESISC_test = split_dataset_3['test']




In [9]:
# !unzip /home/23ucs712/MLRSNet-master.zip -d /home/23ucs712/

In [10]:
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader, random_split

data_dir = "/home/23ucs712/MLRSNet-master/Images"



dataset = ImageFolder(root=data_dir, transform=test_transform)
print("Total samples:", len(dataset))
print("Classes:", len(dataset.classes))


Total samples: 111666
Classes: 46


In [11]:
train_size = int(0.6 * len(dataset))
test_size = len(dataset) - train_size


train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

MLRS_train_loader = DataLoader(train_dataset, batch_size=64,num_workers=num_cpus//2,shuffle=True)
MLRS_test_loader = DataLoader(test_dataset, batch_size=64,num_workers=num_cpus//2,shuffle=False)





In [12]:
# pip install librosa soundfile


In [13]:
domain_num_classes = {
    'EuroSAT': 10,
    'PatternNet': 38,
#    'RESISC45': 45,
    'MLRS': 46,
    'Advance': 13

    
}


In [None]:
dataset_5 = load_dataset("blanchon/ADVANCE", split='train')
split_dataset_5 = dataset_5.train_test_split(test_size=0.2, seed=42)

advance_train = split_dataset_5['train']
advance_test = split_dataset_5['test']

all_labels = set(example['label'] for example in advance_train)
sorted_labels = sorted(list(all_labels))
label_to_idx = {label: idx for idx, label in enumerate(sorted_labels)} 

wrapped_advance_train = HFDatasetWrapper(
    advance_train,
    transform=train_transform,
    label_key='label',
    label_to_idx=label_to_idx
)

wrapped_advance_test = HFDatasetWrapper(
    advance_test,
    transform=test_transform,
    label_key='label',
    label_to_idx=label_to_idx
)

from torch.utils.data import WeightedRandomSampler

labels = [label_to_idx[example['label']] for example in advance_train]
class_sample_counts = [labels.count(i) for i in range(len(sorted_labels))]
weights = [1.0 / class_sample_counts[label] for label in labels]

sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

advance_train_loader = DataLoader(
    wrapped_advance_train,
    batch_size=64,
    sampler=sampler,
    num_workers=num_cpus // 2,
    persistent_workers=True,
    pin_memory=True
)

advance_test_loader = DataLoader(
    wrapped_advance_test,
    batch_size=64,
    shuffle=False,
    num_workers=num_cpus // 2,
    persistent_workers=True,
    pin_memory=True
)





In [17]:
from torchvision.models import resnet50

base = resnet50(weights='IMAGENET1K_V1')

domain_list = ['EuroSAT','PatternNet','MLRS','Advance']
model = ResNetWithAdapters(base,domain_list,domain_num_classes)

# Freeze ResNet backbone
for param in model.stem.parameters():
    param.requires_grad = False

for layer in model.base_layers.values():
    for param in layer.parameters():
        param.requires_grad = False

for param in model.avgpool.parameters():
    param.requires_grad = False


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(device)


In [18]:
train_loaders = {
    'EuroSAT': euroSAT_train_loader,
    'PatternNet': patternNet_train_loader,
#    'RESISC45': RESISC_train,
    'MLRS': MLRS_train_loader,
    'Advance': advance_train_loader
}

test_loaders = {
    'EuroSAT': euroSAT_test_loader,
    'PatternNet': patternNet_test_loader,
#    'RESISC45': RESISC_test,
    'MLRS': MLRS_test_loader,
    'Advance': advance_test_loader

}


In [19]:
def evaluate(model, test_loaders, domain_num_classes):
    model.eval()
    
    with torch.no_grad():
        correct = 0
        total = 0
        print(f"\nEvaluating on domain: {domain}")
        progress_bar = tqdm(test_loaders[domain], desc=f"{domain} Eval", leave=False)

        for images, labels in progress_bar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images, domain)

            # Optional safety check
            if outputs.shape[1] != domain_num_classes[domain]:
                print(f"[WARNING] Output dim {outputs.shape[1]} does not match expected {domain_num_classes[domain]} for {domain}")

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            progress_bar.set_postfix(acc=f"{100 * correct / total:.2f}%")

        accuracy = 100 * correct / total
        print(f"Final Accuracy on {domain}: {accuracy:.2f}%")
        



In [20]:
def freeze_domain(model, current_domain):
    for name, param in model.named_parameters():
        if f".{current_domain}." in name and ("adapter" in name or "classifier" in name):
            param.requires_grad = True
        else:
            param.requires_grad = False


In [21]:
def domain_parameters(model,domain):
    return list(model.adapters[domain].parameters()) + list(model.classifiers[domain].parameters())


In [24]:
def count_parameters(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = total - trainable

    print(f"Total parameters:     {total:,}")
    print(f"Trainable parameters: {trainable:,}")
    print(f"Frozen parameters:    {frozen:,}")
count_parameters(model)


Total parameters:     28,645,547
Trainable parameters: 1,250,058
Frozen parameters:    27,395,489


In [25]:
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR


optimizer = {}
scheduler = {}

for domain in domain_list:
    optimizer[domain] = optim.Adam(domain_parameters(model, domain), lr=1e-3)
    scheduler[domain] = StepLR(optimizer[domain], step_size=15, gamma=0.1)
    
num_epochs = 20


for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    total_samples = 0

    print(f"\nEpoch [{epoch+1}/{num_epochs}]")

    for domain, loader in train_loaders.items():
        criterion = nn.CrossEntropyLoss()
        print(f"\nTraining on domain: {domain}")
        domain_loss = 0.0
        correct = 0
        total = 0

        progress_bar = tqdm(loader, desc=f"[{domain}]", leave=True)

        freeze_domain(model, domain)

        for images, labels in progress_bar:
            images = images.to(device)
            labels = labels.to(device)

            optimizer[domain].zero_grad()
            outputs = model(images, domain)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer[domain].step()

            batch_size = images.size(0)
            domain_loss += loss.item() * batch_size
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            progress_bar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100 * correct / total:.2f}%' 
            })

        avg_domain_loss = domain_loss / len(loader.dataset)
        avg_domain_acc = 100 * correct / total
        print(f"Domain: {domain}, Epoch Loss: {avg_domain_loss:.4f}, Accuracy: {avg_domain_acc:.2f}%")

        scheduler[domain].step()

        evaluate(model, test_loaders, domain_num_classes)

    avg_total_loss = total_loss / total_samples
    print(f"\nEpoch [{epoch+1}/{num_epochs}] Avg Total Loss: {avg_total_loss:.4f}")



Epoch [1/20]

Training on domain: EuroSAT


[EuroSAT]:   0%|                                                                                | 0/102 [00:05<?, ?it/s]


RuntimeError: DataLoader worker (pid(s) 65124, 65125, 65126, 65127, 65128, 65129, 65130, 65131, 65132, 65133, 65134, 65135, 65136, 65137, 65138, 65139, 65140, 65141, 65143, 65144, 65145, 65146, 65147, 65148, 65149, 65151, 65152, 65153, 65154, 65155, 65156, 65157, 65158, 65159, 65160, 65161, 65162, 65163, 65164, 65165) exited unexpectedly