# In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, BatchSampler, RandomSampler
from torch.optim import lr_scheduler

# Preprocessing shared by every dataset: resize each image to 256x256 and
# convert it to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

# First image source (the "spec" directory), train and test splits.
train_dataset_1 = datasets.ImageFolder(
    root="/kaggle/input/icassp/ICASSP_20/spec/train",
    transform=transform,
)
test_dataset_1 = datasets.ImageFolder(
    root="/kaggle/input/icassp/ICASSP_20/spec/test",
    transform=transform,
)

# Second image source (the "4.2b" directory), train and test splits.
train_dataset_2 = datasets.ImageFolder(
    root="/kaggle/input/icassp/ICASSP_20/4.2b/train",
    transform=transform,
)
test_dataset_2 = datasets.ImageFolder(
    root="/kaggle/input/icassp/ICASSP_20/4.2b/test",
    transform=transform,
)

# Report how many samples each split contains, one line per dataset.
print(
    f'Size of train_dataset_1: {len(train_dataset_1)}',
    f'Size of train_dataset_2: {len(train_dataset_2)}',
    f'Size of test_dataset_1: {len(test_dataset_1)}',
    f'Size of test_dataset_2: {len(test_dataset_2)}',
    sep='\n',
)

# Custom BatchSampler to ensure equal batch sizes
class CustomBatchSampler(BatchSampler):
    """Batch sampler that groups indices from ``self.sampler`` into lists.

    Every yielded list holds exactly ``self.batch_size`` indices, except for
    a possible shorter final list, which is emitted only when
    ``self.drop_last`` is false.
    """

    def __iter__(self):
        pending = []
        for sample_index in self.sampler:
            pending.append(sample_index)
            if len(pending) < self.batch_size:
                # Batch is not full yet; keep collecting indices.
                continue
            yield pending
            pending = []
        # Emit the leftover indices unless partial batches must be dropped.
        if pending and not self.drop_last:
            yield pending

# Initialize the DataLoaders.
#
# The training and testing loops take batch k from loader 1 and batch k from
# loader 2, concatenate them along the channel dimension and score the result
# against the labels of loader 1.  Both loaders therefore have to visit the
# dataset indices in the same order.  Two independent RandomSamplers shuffle
# each dataset differently and would pair unrelated images, so both training
# samplers are driven by generators seeded with the same value: they draw the
# same permutation on every epoch as long as the two loaders are always
# iterated together and the datasets have the same length.
# NOTE(review): this assumes index i of dataset 1 and index i of dataset 2
# refer to the same underlying sample -- confirm against the directory layout.
batch_size = 8
pairing_seed = 0


def _seeded_random_sampler(dataset, seed):
    """Return a RandomSampler whose shuffles come from a generator seeded with ``seed``."""
    generator = torch.Generator()
    generator.manual_seed(seed)
    return RandomSampler(dataset, generator=generator)


train_loader_1 = DataLoader(
    train_dataset_1,
    batch_sampler=CustomBatchSampler(
        _seeded_random_sampler(train_dataset_1, pairing_seed),
        batch_size=batch_size,
        drop_last=True,
    ),
)
train_loader_2 = DataLoader(
    train_dataset_2,
    batch_sampler=CustomBatchSampler(
        _seeded_random_sampler(train_dataset_2, pairing_seed),
        batch_size=batch_size,
        drop_last=True,
    ),
)

# Evaluation needs no shuffling: a fixed sequential order keeps the pairs
# aligned by index.  The final partial batch is kept only when both test sets
# have the same length; otherwise the two last batches could differ in size
# and fail to concatenate, so it is dropped as before.
drop_test_tail = len(test_dataset_1) != len(test_dataset_2)
test_loader_1 = DataLoader(test_dataset_1, batch_size=batch_size, shuffle=False, drop_last=drop_test_tail)
test_loader_2 = DataLoader(test_dataset_2, batch_size=batch_size, shuffle=False, drop_last=drop_test_tail)

# Function to pair loaders
def pair_loaders(loader1, loader2):
    """Yield ``(batch1, batch2)`` tuples by walking both loaders in lockstep.

    Iteration ends as soon as either loader is exhausted.
    """
    yield from zip(loader1, loader2)

# Build a ResNet-18 without loading a pre-trained checkpoint.  Calling the
# constructor with its defaults gives the same result as `pretrained=False`
# while avoiding that keyword, which newer torchvision releases deprecate.
# Freshly constructed parameters already require gradients, so no explicit
# `requires_grad = True` pass is needed.
model = models.resnet18()

# Replace the classification head with a 10-way linear layer.
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_ftrs, 10)
)

# Replace the first convolutional layer so the network accepts 6 input
# channels instead of 3: the training and testing loops concatenate two image
# batches along the channel dimension before the forward pass.
model.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Use the GPU when one is available.  The model is moved before the optimizer
# is constructed so the optimizer is built from the parameters on their final
# device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)

# Multiply the learning rate by 0.1 once the monitored value (lower is better)
# has stopped improving for more than `patience` scheduler steps.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases -- confirm
# against the installed version before upgrading.
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

# Training loop
def train(model, train_loader_1, train_loader_2, criterion, optimizer, device):
    """Run one training epoch over two loaders consumed in lockstep.

    Batch k of ``train_loader_1`` and batch k of ``train_loader_2`` are
    concatenated along dimension 1 and fed to ``model`` as a single input.
    The loss is computed against the labels of ``train_loader_1`` only; the
    labels of ``train_loader_2`` are ignored (the code assumes both loaders
    describe the same samples).

    Args:
        model: Network to optimise; switched to training mode.
        train_loader_1: Iterable of ``(inputs, labels)`` batches.
        train_loader_2: Iterable of ``(inputs, labels)`` batches, aligned
            with ``train_loader_1``.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        The mean of the per-batch losses as a float, or ``None`` when no
        batch was processed.
    """
    model.train()
    running_loss = 0.0
    count = 0
    # zip stops as soon as the shorter of the two loaders is exhausted.
    for (inputs_1, labels_1), (inputs_2, _) in zip(train_loader_1, train_loader_2):
        # Move inputs and labels to the device (the second label tensor is
        # never used, so it is not transferred).
        inputs_1, labels_1 = inputs_1.to(device), labels_1.to(device)
        inputs_2 = inputs_2.to(device)

        # Combine the inputs from the two datasets along the channel dimension
        inputs = torch.cat((inputs_1, inputs_2), dim=1)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels_1)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        count += 1

    if count == 0:
        print('No batches processed in training.')
        return None

    # Report the accumulated loss instead of discarding it.
    average_loss = running_loss / count
    print(f'Train loss: {average_loss:.4f}')
    return average_loss

# Testing loop
def test(model, test_loader_1, test_loader_2, criterion, device):
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    count = 0
    with torch.no_grad():
        for (inputs_1, labels_1), (inputs_2, labels_2) in pair_loaders(test_loader_1, test_loader_2):
            inputs_1, labels_1 = inputs_1.to(device), labels_1.to(device)
            inputs_2, labels_2 = inputs_2.to(device), labels_2.to(device)

            inputs = torch.cat((inputs_1, inputs_2), dim=1)

            outputs = model(inputs)
            loss = criterion(outputs, labels_1)
            test_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            total += labels_1.size(0)
            correct += (predicted == labels_1).sum().item()
            count += 1

    if count > 0:
#         print(' ')
        print(f'Accuracy: {100 * correct / total}%')
    else:
        print('No batches processed in testing.')

    if scheduler is not None:
        scheduler.step(test_loss)

# Run the full schedule: each epoch performs one training pass followed by one
# evaluation pass over the paired loaders.
num_epochs = 100
for epoch in range(num_epochs):
    print(f'Epoch {epoch + 1}/{num_epochs}')
    train(
        model=model,
        train_loader_1=train_loader_1,
        train_loader_2=train_loader_2,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
    )
    test(
        model=model,
        test_loader_1=test_loader_1,
        test_loader_2=test_loader_2,
        criterion=criterion,
        device=device,
    )
