In [None]:
import torch
import torch.nn as nn
import torchvision.models as models


class ConvNext(nn.Module):
	"""Quadrant ensemble of four ImageNet-pretrained ConvNeXt-Small backbones.

	``forward`` cuts each image into four equal quadrants and runs every quadrant
	through its own backbone. The last layer of each backbone's classifier
	(``classifier[2]``) is replaced by ``nn.Identity`` so the backbone returns a
	feature vector; the four vectors (3072 values in total, per the ``fc`` input
	size) are concatenated and classified by a small MLP head.

	The attribute names ``convnext``, ``convnext2``, ``convnext3``, ``convnext4``
	and ``fc`` determine the ``state_dict`` keys of saved checkpoints, so they
	must not be renamed.

	Args:
		num_classes: number of logits produced by the head.
	"""

	def __init__(self, num_classes=4):
		super(ConvNext, self).__init__()

		self.convnext = models.convnext_small(pretrained=True)
		self.convnext2 = models.convnext_small(pretrained=True)
		self.convnext3 = models.convnext_small(pretrained=True)
		self.convnext4 = models.convnext_small(pretrained=True)

		# Replace the classifiers' final layer with identities to extract features.
		for backbone in (self.convnext, self.convnext2, self.convnext3, self.convnext4):
			backbone.classifier[2] = nn.Identity()

		self.fc = nn.Sequential(
			nn.Linear(3072, 512),
			nn.Dropout(0.2),
			nn.ReLU(),
			nn.Linear(512, num_classes)
		)

	@staticmethod
	def _split_quadrants(image):
		"""Split a (B, C, H, W) batch into (top_left, top_right, bottom_left, bottom_right).

		Raises:
			ValueError: if H or W is odd.
		"""
		_, _, height, width = image.size()
		if height % 2 != 0 or width % 2 != 0:
			raise ValueError("Image height and width must be divisible by 2.")
		half_h, half_w = height // 2, width // 2
		return (
			image[:, :, :half_h, :half_w],
			image[:, :, :half_h, half_w:],
			image[:, :, half_h:, :half_w],
			image[:, :, half_h:, half_w:],
		)

	def forward(self, image, image2=None):
		"""Return class logits for a batch of images.

		Args:
			image: batch of shape (B, C, H, W) with even H and W.
			image2: ignored; accepted (now optional) so callers that pass a
				second argument keep working.
		"""
		quadrants = self._split_quadrants(image)
		backbones = (self.convnext, self.convnext2, self.convnext3, self.convnext4)
		# Each quadrant has its own backbone: top-left -> convnext, ..., bottom-right -> convnext4.
		features = [backbone(quadrant) for backbone, quadrant in zip(backbones, quadrants)]
		combined_features = torch.cat(features, dim=1)
		return self.fc(combined_features.flatten(1))



class SWIN(nn.Module):
	"""Quadrant ensemble of four ImageNet-pretrained Swin-B backbones.

	``forward`` cuts each image into four equal quadrants and runs every quadrant
	through its own ``swin_b``. Each backbone's ``head`` is replaced by
	``nn.Identity`` so it returns a feature vector; the four vectors (4096 values
	in total, per the ``fc`` input size) are concatenated and classified by a
	small MLP head.

	The attribute names ``swin``, ``swin2``, ``swin3``, ``swin4`` and ``fc``
	determine the ``state_dict`` keys of saved checkpoints, so they must not be
	renamed.

	Args:
		num_classes: number of logits produced by the head.
	"""

	def __init__(self, num_classes=4):
		super(SWIN, self).__init__()

		self.swin = models.swin_b(pretrained=True)
		self.swin2 = models.swin_b(pretrained=True)
		self.swin3 = models.swin_b(pretrained=True)
		self.swin4 = models.swin_b(pretrained=True)

		# Replace the classification heads with identities to extract features.
		for backbone in (self.swin, self.swin2, self.swin3, self.swin4):
			backbone.head = nn.Identity()

		self.fc = nn.Sequential(
			nn.Linear(4096, 512),
			nn.Dropout(0.2),
			nn.ReLU(),
			nn.Linear(512, num_classes)
		)

	@staticmethod
	def _split_quadrants(image):
		"""Split a (B, C, H, W) batch into (top_left, top_right, bottom_left, bottom_right).

		Raises:
			ValueError: if H or W is odd.
		"""
		_, _, height, width = image.size()
		if height % 2 != 0 or width % 2 != 0:
			raise ValueError("Image height and width must be divisible by 2.")
		half_h, half_w = height // 2, width // 2
		return (
			image[:, :, :half_h, :half_w],
			image[:, :, :half_h, half_w:],
			image[:, :, half_h:, :half_w],
			image[:, :, half_h:, half_w:],
		)

	def forward(self, image, image2=None):
		"""Return class logits for a batch of images.

		Args:
			image: batch of shape (B, C, H, W) with even H and W.
			image2: ignored; accepted (now optional) so callers that pass a
				second argument keep working.
		"""
		quadrants = self._split_quadrants(image)
		backbones = (self.swin, self.swin2, self.swin3, self.swin4)
		# Each quadrant has its own backbone: top-left -> swin, ..., bottom-right -> swin4.
		features = [backbone(quadrant) for backbone, quadrant in zip(backbones, quadrants)]
		combined_features = torch.cat(features, dim=1)
		return self.fc(combined_features.flatten(1))
	

In [None]:

class CombinedModel(nn.Module):
	"""Feature-level fusion of a ConvNext and a SWIN quadrant ensemble.

	Both sub-models are optionally restored from their own checkpoints. Their
	``fc`` heads are then replaced by ``nn.Identity`` so each emits its
	concatenated quadrant features, and a new MLP head classifies the
	concatenation of both feature vectors.

	Args:
		num_classes: number of logits produced by the new head.
		convnext_ckpt: path of the ConvNext ``state_dict`` to restore, or ``None``
			to keep the freshly constructed (ImageNet-initialised) weights.
		swin_ckpt: path of the SWIN ``state_dict`` to restore, or ``None``.
	"""

	def __init__(self, num_classes=4, convnext_ckpt='./Saved/4x-ConvNext-E10-448.pt', swin_ckpt='./Saved/4x-SWIN-E16-448.pt'):
		super(CombinedModel, self).__init__()

		# Sub-models keep their default num_classes so their fc shapes match the checkpoints.
		self.convNext = ConvNext()
		self.swin = SWIN()
		device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

		if convnext_ckpt is not None:
			self.convNext.load_state_dict(torch.load(convnext_ckpt, map_location=device))
		if swin_ckpt is not None:
			self.swin.load_state_dict(torch.load(swin_ckpt, map_location=device))

		# Fused width = sum of the sub-models' head input widths (3072 + 4096 = 7168),
		# read before the heads are discarded.
		fused_dim = self.convNext.fc[0].in_features + self.swin.fc[0].in_features

		# Replace the sub-model heads with identities to extract features.
		self.convNext.fc = nn.Identity()
		self.swin.fc = nn.Identity()

		self.fc = nn.Sequential(
			nn.Linear(fused_dim, 512),
			nn.ReLU(),
			nn.Dropout(0.2),
			nn.Linear(512, num_classes)
		)

	def forward(self, image1):
		"""Return class logits for a batch of images ``image1``."""
		# The sub-models ignore their second argument; '' only fills the slot.
		convnext_features = self.convNext(image1, '')
		swin_features = self.swin(image1, '')
		combined_features = torch.cat((convnext_features, swin_features), dim=1)
		return self.fc(combined_features)




In [None]:
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import os
from tqdm import tqdm
from PIL import Image
import numpy as np
import cv2 as cv

# Per-channel ImageNet statistics used by the pretrained ConvNeXt / Swin backbones.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize to a fixed 448x448, convert to a float tensor, then normalise per channel.
transform = transforms.Compose(
	[
		transforms.Resize((448, 448)),
		transforms.ToTensor(),
		transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
	]
)


class CustomImageDataset(Dataset):
	"""Paired-view image dataset, fully decoded and transformed at construction time.

	Layout read by the constructor::

		image_dir/<class folder>/<case folder>/<image files>

	For every case folder two images are placed side by side. With exactly two
	files, the one whose name contains ``'CC'`` (checked on the first listed
	file) goes on the left. With any other file count, the first listed file is
	used for both halves. Each item is ``(image, label)``, where ``image`` is the
	transformed side-by-side image (or the PIL image when ``transform`` is None)
	and ``label`` is derived from the class-folder name.
	"""

	def __init__(self, image_dir, transform=None):
		self.image_dir = image_dir
		self.transform = transform
		# Entries are (image, image, label); the 3-tuple layout is kept for
		# compatibility and only the first image is returned by __getitem__.
		self.image_labels = []
		for folder in os.listdir(image_dir):
			folder_path = os.path.join(image_dir, folder)
			if not os.path.isdir(folder_path):
				continue  # ignore stray files (e.g. .DS_Store) next to class folders
			label = self._label_for_folder(folder)

			for sub_folder in tqdm(os.listdir(folder_path)):
				case_path = os.path.join(folder_path, sub_folder)
				if not os.path.isdir(case_path):
					continue
				file1, file2 = self._pick_view_paths(case_path, os.listdir(case_path))
				image = self._load_pair(file1, file2)
				self.image_labels.append((image, image, label))

	@staticmethod
	def _label_for_folder(folder):
		"""Map a class-folder name to an integer label by substring match, in order:
		'1' -> 0, '2' -> 1, '4' -> 2, anything else -> 3."""
		# NOTE(review): '4' -> 2 and e.g. '3' -> 3 looks deliberate (presumably the label
		# order the saved checkpoints were trained with) - confirm before changing.
		if '1' in folder:
			return 0
		if '2' in folder:
			return 1
		if '4' in folder:
			return 2
		return 3

	@staticmethod
	def _pick_view_paths(case_path, files):
		"""Return the (left, right) image paths for one case folder.

		Raises:
			FileNotFoundError: if the case folder contains no files.
		"""
		if not files:
			raise FileNotFoundError(f"No image files found in {case_path}")
		if len(files) == 2:
			left, right = files if 'CC' in files[0] else (files[1], files[0])
		else:
			left = right = files[0]
		return os.path.join(case_path, left), os.path.join(case_path, right)

	def _load_pair(self, file1, file2):
		"""Read two images, join them horizontally and apply ``self.transform`` if set.

		Raises:
			FileNotFoundError: if either image cannot be read.
		"""
		left = cv.imread(file1)
		right = cv.imread(file2)
		# cv.imread reports failure by returning None rather than raising.
		for path, img in ((file1, left), (file2, right)):
			if img is None:
				raise FileNotFoundError(f"Could not read image: {path}")
		# NOTE(review): cv.imread yields BGR channel order while PIL / ImageNet
		# normalisation assume RGB; harmless if the source images are grayscale
		# (identical channels) - confirm before adding a cv.cvtColor conversion.
		combined_image = Image.fromarray(cv.hconcat([left, right]))
		# Without a transform, the PIL image itself is stored.
		return self.transform(combined_image) if self.transform else combined_image

	def __len__(self):
		return len(self.image_labels)

	def __getitem__(self, idx):
		"""Return ``(image, label)`` for item ``idx``."""
		image, _, label = self.image_labels[idx]
		return image, label

trainPath = "./Data-448/"
valPath = "./Data-448-Val/"

print("Creating Custom Data")
train_dataset = CustomImageDataset(image_dir=trainPath, transform=transform)
val_dataset = CustomImageDataset(image_dir=valPath, transform=transform)


def get_batches(dataset, batch_size, shuffle=True):
	"""Yield ``(images, labels)`` mini-batches from an indexable dataset.

	Every ``dataset[i]`` must be an ``(image_tensor, int_label)`` pair. Images are
	stacked along a new leading dimension and labels become a 1-D tensor. When
	``shuffle`` is true the order comes from NumPy's global RNG; the final batch
	may be shorter than ``batch_size``.
	"""
	order = np.arange(len(dataset))
	if shuffle:
		np.random.shuffle(order)
	for offset in range(0, len(dataset), batch_size):
		images, targets = [], []
		for idx in order[offset:offset + batch_size]:
			img, target = dataset[idx]
			images.append(img)
			targets.append(target)
		yield torch.stack(images), torch.tensor(targets)


In [None]:
def _set_backbones_trainable(model, trainable):
    """Set ``requires_grad`` on every parameter of ``model.convNext`` then ``model.swin``."""
    for backbone in (model.convNext, model.swin):
        for param in backbone.parameters():
            param.requires_grad = trainable


def freeze_layers(model):
    """Freeze both feature extractors so only parameters outside them get gradients."""
    _set_backbones_trainable(model, False)


def unfreeze_layers(model):
    """Make both feature extractors trainable again."""
    _set_backbones_trainable(model, True)

In [None]:
from sklearn.utils.class_weight import compute_class_weight

y_train = [label for _, label in train_dataset]

# Use GPU index 3 when the machine has one; otherwise fall back to the default
# CUDA device or the CPU instead of failing on hosts with fewer GPUs.
if torch.cuda.device_count() > 3:
	device = torch.device('cuda', 3)
elif torch.cuda.is_available():
	device = torch.device('cuda')
else:
	device = torch.device('cpu')

# 'balanced' weights are inversely proportional to each class's frequency in y_train.
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

print(class_weights)

model = CombinedModel(num_classes=4)

optimizer = optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)

		
def train_model(model, train_dataset, val_dataset, criterion, optimizer, num_epochs=10, batch_size=16,
		checkpoint_dir='./Models', checkpoint_offset=3):
	"""Train ``model`` with its ConvNext/SWIN extractors frozen, validating after each epoch.

	Uses the module-level ``device`` plus the file's ``get_batches`` and
	``freeze_layers`` helpers. After every epoch the full ``state_dict`` is saved
	to ``<checkpoint_dir>/model<epoch + checkpoint_offset>.pt``; the defaults
	reproduce the original naming (the first epoch is written as
	``./Models/model3.pt``).

	Args:
		model: network whose ``convNext`` and ``swin`` sub-modules get frozen.
		train_dataset / val_dataset: datasets yielding ``(image, label)`` pairs.
		criterion: loss applied to ``(logits, labels)``.
		optimizer: optimizer stepping the model's trainable parameters.
		num_epochs: number of passes over ``train_dataset``.
		batch_size: mini-batch size for both training and validation.
		checkpoint_dir: directory for per-epoch checkpoints (created if missing).
		checkpoint_offset: added to the 0-based epoch index in checkpoint names.
	"""
	model.to(device)
	freeze_layers(model)
	# torch.save does not create directories; do it before spending an epoch training.
	os.makedirs(checkpoint_dir, exist_ok=True)

	for epoch in range(num_epochs):
		model.train()
		running_loss = 0.0

		for step, (image1s, labels) in enumerate(get_batches(train_dataset, batch_size, shuffle=True), start=1):
			image1s = image1s.to(device)
			labels = labels.to(device).long()  # CrossEntropyLoss expects int64 targets

			optimizer.zero_grad()
			outputs = model(image1s)
			loss = criterion(outputs, labels)
			loss.backward()
			optimizer.step()

			batch_loss = loss.item()
			print(f'Epoch {epoch+1}/{num_epochs}, Step {step}, Loss: {batch_loss:.6f}')
			# Weight by batch size so the epoch mean is per-sample despite a short last batch.
			running_loss += batch_loss * image1s.size(0)

		# max(..., 1) avoids ZeroDivisionError on an empty training set.
		epoch_loss = running_loss / max(len(train_dataset), 1)
		print(f'Epoch {epoch+1}/{num_epochs}, Average Loss: {epoch_loss:.4f}')

		model.eval()
		val_loss = 0.0
		correct = 0
		total = 0
		with torch.no_grad():
			for image1s, labels in get_batches(val_dataset, batch_size, shuffle=False):
				image1s = image1s.to(device)
				labels = labels.to(device).long()
				outputs = model(image1s)
				loss = criterion(outputs, labels)

				val_loss += loss.item() * image1s.size(0)
				_, predicted = torch.max(outputs, 1)
				total += labels.size(0)
				correct += (predicted == labels).sum().item()

		if total:
			val_loss /= total
			val_accuracy = correct / total * 100
			print(f'Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%')
		else:
			print('Validation set is empty; skipping validation metrics.')
		torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'model{epoch + checkpoint_offset}.pt'))

# Single epoch, 32 images per step.
train_model(
	model, train_dataset, val_dataset, criterion, optimizer,
	batch_size=32, num_epochs=1,
)


In [None]:
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, f1_score

model = CombinedModel(num_classes=4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Forward slashes work on Windows and POSIX alike and match the ./Models/ path the
# training cell writes to; a backslash path is a different filename on POSIX and
# '\M' / '\m' are invalid string escapes.
model.load_state_dict(torch.load('./Models/model3.pt', map_location=device))
def evaluate_model(model, testData, device, batch_size=8):
    """Run ``model`` over ``testData``, plot its confusion matrix and report weighted F1.

    Args:
        model: classifier returning logits of shape (batch, num_classes); it is
            expected to already live on ``device``.
        testData: dataset of ``(image, label)`` pairs, batched via ``get_batches``.
        device: device the input batches are moved to.
        batch_size: evaluation batch size.

    Returns:
        ``(cm, f1)``: the confusion matrix (rows = true class, columns =
        predicted class) and the support-weighted F1 score.
    """
    model.eval()
    all_preds = []
    all_labels = []
    num_classes = None

    with torch.no_grad():
        for image1s, labels in get_batches(testData, batch_size, shuffle=False):
            image1s, labels = image1s.to(device), labels.to(device)
            outputs = model(image1s)
            num_classes = outputs.size(1)
            preds = outputs.argmax(dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Pin the label set so a class missing from both truth and predictions still
    # gets its own row/column; otherwise the matrix shrinks and indices misalign.
    class_ids = list(range(num_classes)) if num_classes is not None else None
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)

    f1 = f1_score(all_labels, all_preds, average='weighted')

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False, ax=ax)
    ax.set(xlabel='Predicted', ylabel='True', title='Confusion Matrix')
    plt.show()

    print(f'F1 Score: {f1:.4f}')

    return cm, f1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
evaluate_model(model, val_dataset, device)
