In [1]:
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.cuda.amp import autocast, GradScaler
from PIL import Image

In [2]:
class CustomDataset(Dataset):
    def __init__(self, txt_file, transform=None):
        self.img_files, self.labels = self.load_img_files(txt_file)
        self.transform = transform
    
    def __len__(self):
        return len(self.img_files)
    
    def __getitem__(self, idx):
        img = cv2.imread(self.img_files[idx])
        
        if img is None:
            raise FileNotFoundError(f"Image {self.img_files[idx]} not found")
        
        img = cv2.resize(img, (256, 256))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        # Convert RGB to greyscale
        gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        gray_img = np.expand_dims(gray_img, axis=2)  # Add channel dimension

        # Stack RGB and greyscale images along the channel dimension
        img = np.concatenate((img, gray_img), axis=2)

        img = Image.fromarray(img)  # Convert numpy array to PIL Image

        if self.transform:
            img = self.transform(img)
        
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return img, label
    
    def load_img_files(self, filename):
        if not os.path.isfile(filename):
            raise FileNotFoundError(f"{filename} not found")

        with open(filename, 'r') as f:
            lines = f.readlines()
        
        img_files, labels = [], []

        for line in lines:
            fn, label = line.strip().split(' ')
            img_files.append(fn)
            labels.append(int(label))
        
        return img_files, labels

In [3]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

In [4]:
class MultiChannelConvolution(nn.Module):
    def __init__(self, num_classes=100):
        super(MultiChannelConvolution, self).__init__()
        
        # Three channels (RGB)
        self.conv_rgb = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=7, padding=3)
        
        # Two channels combinations
        self.conv_rg = nn.Conv2d(in_channels=2, out_channels=32, kernel_size=7, padding=3)
        self.conv_rb = nn.Conv2d(in_channels=2, out_channels=32, kernel_size=7, padding=3)
        self.conv_gb = nn.Conv2d(in_channels=2, out_channels=32, kernel_size=7, padding=3)
        
        # One channel
        self.conv_r = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=7, padding=3)
        self.conv_g = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=7, padding=3)
        self.conv_b = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=7, padding=3)
        self.conv_gray = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=7, padding=3)
        
        self.bn1 = nn.BatchNorm2d(256)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        # ResNet block
        self.resnet_block = BasicBlock(in_channels=256, out_channels=256, stride=2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        # Split the channels
        rgb = x[:, :3, :, :]      # RGB channels
        gray = x[:, 3:, :, :]     # Greyscale channel
        
        r = rgb[:, 0:1, :, :]
        g = rgb[:, 1:2, :, :]
        b = rgb[:, 2:3, :, :]

        # Apply convolutions
        out_rgb = self.conv_rgb(rgb)
        
        out_rg = self.conv_rg(torch.cat([r, g], dim=1))
        out_rb = self.conv_rb(torch.cat([r, b], dim=1))
        out_gb = self.conv_gb(torch.cat([g, b], dim=1))
        
        out_r = self.conv_r(r)
        out_g = self.conv_g(g)
        out_b = self.conv_b(b)
        out_gray = self.conv_gray(gray)

        # Collect all results
        conv_out = torch.cat([
            out_rgb, out_rg, out_rb, out_gb,
            out_r, out_g, out_b, out_gray
        ], dim=1)
        conv_out = self.bn1(conv_out)
        conv_out = self.maxpool(conv_out)
        conv_out = torch.relu(conv_out)

        # Apply the ResNet block
        res_out = self.resnet_block(conv_out)
        res_out = self.avgpool(res_out)

        # Flatten the output for the fully connected layer
        res_out_flat = res_out.view(res_out.size(0), -1)
        
        # Apply the classifier
        class_out = self.fc(res_out_flat)

        return class_out

In [5]:
# Function to find the latest checkpoint
def find_latest_checkpoint(model_name):
    checkpoints = [f for f in os.listdir(model_name) if f.startswith(model_name) and f.endswith('.pth')]
    if not checkpoints:
        return None, 0
    checkpoints.sort()
    latest_checkpoint = checkpoints[-1]
    epoch = int(latest_checkpoint.split('_epoch')[1].split('.')[0])
    return os.path.join(model_name, latest_checkpoint), epoch

In [6]:
# Check if GPU is available
num_epochs = 50
batch_size = 16
lr = 0.0002

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406, 0.5], std=[0.229, 0.224, 0.225, 0.5])
])

# Load the training and validation datasets
train_dataset = CustomDataset(txt_file='train.txt', transform=transform)
val_dataset = CustomDataset(txt_file='val.txt', transform=transform)
test_dataset = CustomDataset(txt_file='test.txt', transform=transform)

train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size, shuffle=False)

####################################################################################################
# Initialize the ResNet model, loss function, and optimizer
num_classes = len(set(train_dataset.labels))  # Number of unique classes in the dataset
model = MultiChannelConvolution(num_classes).to(device)
model_name = 'TwoLayer'
os.makedirs(model_name, exist_ok=True)

latest_checkpoint, start_epoch = find_latest_checkpoint(model_name)

if latest_checkpoint:
    model.load_state_dict(torch.load(latest_checkpoint))
    print(f"Loaded checkpoint '{latest_checkpoint}' (epoch {start_epoch})")
else:
    start_epoch = 0

# Create log files
batch_loss_log_path = os.path.join(model_name, f'{model_name}-batch_loss_log.txt')
batch_loss_log = open(batch_loss_log_path, "w")
epoch_log_path = os.path.join(model_name, f'{model_name}-epoch_log.txt')
epoch_log = open(epoch_log_path, "w")
test_log_path = os.path.join(model_name, f'{model_name}-test_log.txt')
test_log = open(epoch_log_path, "w")

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

if latest_checkpoint:
    optimizer_state_path = latest_checkpoint.replace('.pth', '_optimizer.pth')
    if os.path.exists(optimizer_state_path):
        optimizer.load_state_dict(torch.load(optimizer_state_path))

# Training loop
for epoch in range(start_epoch, num_epochs):
    print(f'Epoch {epoch + 1}, start')
    model.train()
    running_loss = 0.0

    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * inputs.size(0)
        
        if i % 10 == 9:    # print every 10 mini-batches
            batch_loss_log.write(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {loss.item():.3f}\n')
            running_loss = 0.0
    
    model_save_path = os.path.join(model_name, f'{model_name}_epoch{epoch+1:02}.pth')
    torch.save(model.state_dict(), model_save_path)
    
    # Validation loop
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            val_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            correct += torch.sum(preds == labels.data)
    
    epoch_val_loss = val_loss / len(val_loader.dataset)
    epoch_val_acc = correct.double() / len(val_loader.dataset)
    epoch_log.write(f'Epoch {epoch + 1}, Validation Loss: {epoch_val_loss:.3f}, Validation Accuracy: {epoch_val_acc:.3f}\n')
    print(f'Validation Loss: {epoch_val_loss:.3f}, Validation Accuracy: {epoch_val_acc:.3f}')

model.eval()
test_loss = 0.0
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        test_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        correct += torch.sum(preds == labels.data)

epoch_test_loss = test_loss / len(test_loader.dataset)
epoch_test_acc = correct.double() / len(test_loader.dataset)
test_log.write(f'Epoch {num_epochs}, Test Loss: {epoch_test_loss:.3f}, Test Accuracy: {epoch_test_acc:.3f}\n')
print(f'Test Loss: {epoch_test_loss:.3f}, Test Accuracy: {epoch_test_acc:.3f}')

print('Finished Training')

# Close log files
batch_loss_log.close()
epoch_log.close()

Using device: cuda
Loaded checkpoint 'TwoLayer/TwoLayer_epoch28.pth' (epoch 28)
Epoch 29, start


  return F.conv2d(input, weight, bias, self.stride,


Validation Loss: 1.690, Validation Accuracy: 0.513
Epoch 30, start
Validation Loss: 1.676, Validation Accuracy: 0.542
Epoch 31, start
Validation Loss: 1.691, Validation Accuracy: 0.531
Epoch 32, start
Validation Loss: 1.694, Validation Accuracy: 0.540
Epoch 33, start
Validation Loss: 1.642, Validation Accuracy: 0.538
Epoch 34, start
Validation Loss: 1.663, Validation Accuracy: 0.549
Epoch 35, start
Validation Loss: 1.676, Validation Accuracy: 0.538
Epoch 36, start
Validation Loss: 1.643, Validation Accuracy: 0.547
Epoch 37, start
Validation Loss: 1.666, Validation Accuracy: 0.551
Epoch 38, start
Validation Loss: 1.643, Validation Accuracy: 0.564
Epoch 39, start
Validation Loss: 1.667, Validation Accuracy: 0.533
Epoch 40, start
Validation Loss: 1.592, Validation Accuracy: 0.571
Epoch 41, start
Validation Loss: 1.636, Validation Accuracy: 0.547
Epoch 42, start
Validation Loss: 1.674, Validation Accuracy: 0.531
Epoch 43, start
Validation Loss: 1.613, Validation Accuracy: 0.549
Epoch 44, s

KeyboardInterrupt: 