# Relation Networks

Relation Networks (RNs) are designed for few-shot learning by explicitly learning relationships between query and support examples. They use a neural network to dynamically model similarity, enabling flexible and effective classification with minimal labeled data.

##### Step 1: Import all necessary libraries

In [13]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
import os


In [2]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset path used later in this
# notebook lives under that mount point.
drive.mount('/content/drive')

Mounted at /content/drive


#### Step 2: We will load the RHA dataset, transformed into .npy format using a helper script.
The helper script simply loads the data into an array of shape: [total_number, character, 64, 64].

In [3]:

# # Function to create dataset from folders
# def create_npy_from_folders(data_dir, img_size=(224, 224)):
#     all_classes = sorted(os.listdir(data_dir))
#     all_images = []
#     min_samples = float('inf')

#     for class_folder in all_classes:
#         class_path = os.path.join(data_dir, class_folder)
#         if not os.path.isdir(class_path):
#             continue
#         class_images = []
#         for img_file in sorted(os.listdir(class_path)):
#             img_path = os.path.join(class_path, img_file)
#             if img_file.endswith(('.png', '.jpg', '.jpeg')):
#                 try:
#                     img = Image.open(img_path).convert('L')
#                     img = img.resize(img_size)
#                     class_images.append(np.array(img))
#                 except Exception as e:
#                     print(f"Error loading image {img_path}: {e}")
#                     continue
#         min_samples = min(min_samples, len(class_images))
#         all_images.append(class_images)

#     truncated_images = [class_images[:min_samples] for class_images in all_images]
#     dataset = np.array(truncated_images)
#     np.save('data.npy', dataset)
#     return dataset

# # Load the dataset
# data_dir = '/content/drive/MyDrive/_projects/GEI_Project/Dataset_fewshot'
# dataset = create_npy_from_folders(data_dir, img_size=(224, 224))

# print(f"Dataset shape: {dataset.shape}")

# # Split each class separately
# x_train, x_val, x_test = [], [], []
# y_train, y_val, y_test = [], [], []

# for class_idx in range(dataset.shape[0]):
#     class_samples = dataset[class_idx]
#     train, temp = train_test_split(class_samples, test_size=0.3, random_state=42)
#     val, test = train_test_split(temp, test_size=0.5, random_state=42)

#     x_train.append(train)
#     x_val.append(val)
#     x_test.append(test)

#     y_train.extend([class_idx] * len(train))
#     y_val.extend([class_idx] * len(val))
#     y_test.extend([class_idx] * len(test))

# # Convert lists to numpy arrays
# x_train = np.array(x_train)
# x_val = np.array(x_val)
# x_test = np.array(x_test)
# y_train = np.array(y_train)
# y_val = np.array(y_val)
# y_test = np.array(y_test)

# print(f"x_train shape: {x_train.shape}, y_train shape: {y_train.shape}")
# print(f"x_val shape: {x_val.shape}, y_val shape: {y_val.shape}")
# print(f"x_test shape: {x_test.shape}, y_test shape: {y_test.shape}")

Let's visualize one example from each class.

In [4]:
# for class_idx in range(x_train.shape[0]):  # Loop over classes (3 classes in your case)
#     print("class " + str(class_idx))
#     # Extract the first image from the current class
#     first_image = x_train[class_idx, 0, :, :]  # (height, width)

#     # Plot the first image of each class
#     plt.figure()
#     plt.imshow(first_image, cmap='gray')  # Display as grayscale
#     plt.title(f"Class {class_idx}, Image 0")  # Add a title with class index and image number
#     plt.axis('off')  # Hide axis for better visualization
#     plt.show()

###### Step 3: Training Data Processing
To load the dataset and prepare it for the Relation Network architecture, we need to create:
1. Label Set: Variable choose_label
2. Support Set: support_set_x, support_set_y
3. Batch from Support Set Examples

Let's first create a batch that provides a support set and a target set.

##### Step 3: Create a Relation Network

In [5]:
# Custom dataset for loading images from folders
class FewShotDataset(Dataset):
    """Image dataset read from a ``root_dir/<class_name>/<image>`` folder tree.

    Every sub-directory of ``root_dir`` is treated as one class. Classes are
    sorted by name and labelled ``0..n_classes-1`` in that order, so the
    class -> label mapping is the same on every run and every machine.

    Args:
        root_dir: Directory whose sub-directories each hold one class.
        transform: Optional callable applied to the loaded RGB ``PIL.Image``
            before it is returned.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir() yields entries in arbitrary order, which would make the
        # label of a class depend on the filesystem. Sort for reproducibility
        # and ignore stray files sitting next to the class folders.
        self.classes = sorted(
            entry
            for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        self.image_paths = []
        self.labels = []

        for label, class_dir in enumerate(self.classes):
            class_path = os.path.join(root_dir, class_dir)
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                # Nested directories are not images; skip them.
                if not os.path.isfile(img_path):
                    continue
                self.image_paths.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``."""
        img_path = self.image_paths[idx]
        # Image.open() is lazy and keeps the file open; convert() forces the
        # pixel data to load and returns a new image, so the handle can be
        # closed right away instead of waiting for garbage collection.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Feature extractor (e.g., CNN)
class FeatureExtractor(nn.Module):
    """Two-stage convolutional encoder that returns one flat vector per image.

    Each stage is Conv2d(3x3, padding 1, 64 output channels) -> ReLU ->
    MaxPool2d(2). The first stage takes 3 input channels, the second 64.
    """

    def __init__(self):
        super().__init__()
        stages = []
        in_channels = 3
        for _ in range(2):
            stages.extend([
                nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
            in_channels = 64
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Encode a batch of images and flatten everything but the batch axis."""
        batch_size = x.size(0)
        feature_maps = self.conv(x)
        return feature_maps.view(batch_size, -1)

# Relation module
class RelationModule(nn.Module):
    """Small MLP that maps a feature vector to a single score in (0, 1).

    Layers: Linear(input_dim, 256) -> ReLU -> Linear(256, 1) -> Sigmoid.
    """

    def __init__(self, input_dim):
        super().__init__()
        hidden_units = 256
        scorer_layers = [
            nn.Linear(input_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*scorer_layers)

    def forward(self, x):
        """Return the sigmoid score for every row of ``x``."""
        scores = self.fc(x)
        return scores

# Few-shot relation network
class RelationNetwork(nn.Module):
    """Scores every (query, support) pair with a learned relation function.

    Both inputs go through the same ``FeatureExtractor``; each query vector is
    concatenated with each support vector and the pair is scored by a
    ``RelationModule`` whose input width is ``feature_dim + relation_dim``.
    """

    def __init__(self, feature_dim, relation_dim):
        super().__init__()
        self.feature_extractor = FeatureExtractor()
        pair_dim = feature_dim + relation_dim
        self.relation_module = RelationModule(pair_dim)

    def forward(self, support, query):
        """Return a ``(N_query, N_support)`` matrix of relation scores."""
        support_emb = self.feature_extractor(support)  # (N_support, D)
        query_emb = self.feature_extractor(query)      # (N_query, D)

        n_support = support_emb.size(0)
        n_query = query_emb.size(0)

        # Broadcast both sets onto a (N_query, N_support, D) grid so that
        # position [q, s] holds the features of query q and support s.
        query_grid = query_emb[:, None, :].expand(-1, n_support, -1)
        support_grid = support_emb[None, :, :].expand(n_query, -1, -1)

        pairs = torch.cat((query_grid, support_grid), dim=-1)
        pair_rows = pairs.view(-1, pairs.size(-1))
        scores = self.relation_module(pair_rows)
        return scores.view(n_query, n_support)


##### Step 3: Training the Relation Network

In [6]:
# Input images are resized to this square size (pixels) before training.
width = height = 224
# Dataset root on the mounted Drive.
data_path = '/content/drive/MyDrive/_projects/GEI_Project/Dataset_fewshot'
# Maximum number of passes over the dataset.
total_epoches = 50

# Metrics storage
train_loss, train_accuracy = [], []

In [9]:
# Training example
def train_relation_network():
    """Train a RelationNetwork on the images under ``data_path``.

    Each shuffled mini-batch is split in two: the first ``n_support`` samples
    act as the support set and the remaining ones as queries. The network is
    trained with binary cross-entropy to output 1 for (query, support) pairs
    that share a label and 0 otherwise. The reported accuracy is the fraction
    of pairs whose score, thresholded at 0.5, matches that target.

    Per-epoch loss (summed over batches) and accuracy are printed and also
    stored in the module-level ``train_loss`` / ``train_accuracy`` lists so
    they can be plotted afterwards. Training stops early once the epoch
    accuracy exceeds 0.7, or after ``total_epoches`` epochs.

    Raises:
        ValueError: If no batch contains more than ``n_support`` samples, so
            no query set could be formed.
    """
    batch_size = 16
    n_support = 8  # leading samples of each batch used as the support set

    # Define image transformations
    transform = transforms.Compose([
        # torchvision's Resize takes (height, width), in that order.
        transforms.Resize((height, width)),
        transforms.ToTensor(),
    ])

    # Load dataset
    dataset = FewShotDataset(root_dir=data_path, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Initialize model
    # The extractor has 64 output channels and two 2x2 max-poolings, so each
    # spatial side shrinks by a factor of 4 (224 -> 56 for the default size).
    feature_dim = 64 * (height // 4) * (width // 4)
    model = RelationNetwork(feature_dim=feature_dim, relation_dim=feature_dim)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.BCELoss()

    # Start from empty histories so a re-run does not append to old curves.
    train_loss.clear()
    train_accuracy.clear()

    for epoch in range(total_epoches):
        epoch_loss = 0.0
        correct = 0
        total = 0

        for images, labels in dataloader:
            # The last batch may hold n_support samples or fewer, which would
            # leave the query set empty; such a batch cannot be trained on.
            if images.size(0) <= n_support:
                continue

            support, query = images[:n_support], images[n_support:]
            support_labels, query_labels = labels[:n_support], labels[n_support:]

            relations = model(support, query)
            # target[q, s] is 1.0 when query q and support s share a label.
            target = (query_labels.unsqueeze(1) == support_labels.unsqueeze(0)).float()

            loss = criterion(relations, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            # Calculate accuracy
            predicted = (relations > 0.5).float()
            correct += (predicted == target).sum().item()
            total += target.numel()

        if total == 0:
            raise ValueError(
                f"No batch with more than {n_support} samples was found in "
                f"{data_path!r}; cannot form support/query sets."
            )

        epoch_accuracy = correct / total
        train_loss.append(epoch_loss)
        train_accuracy.append(epoch_accuracy)
        print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}")

        if epoch_accuracy > 0.7:
            break


In [10]:

# Run training (prints the loss and accuracy of every epoch)
train_relation_network()


Epoch 1, Loss: 616.3077, Accuracy: 0.6712
Epoch 2, Loss: 610.9375, Accuracy: 0.6807
Epoch 3, Loss: 637.5000, Accuracy: 0.6635
Epoch 4, Loss: 639.0625, Accuracy: 0.6678
Epoch 5, Loss: 632.8125, Accuracy: 0.6661
Epoch 6, Loss: 637.5000, Accuracy: 0.6661
Epoch 7, Loss: 635.9375, Accuracy: 0.6695
Epoch 8, Loss: 578.1250, Accuracy: 0.7012
Epoch 9, Loss: 610.9375, Accuracy: 0.6755
Epoch 10, Loss: 620.3125, Accuracy: 0.6755
Epoch 11, Loss: 614.0625, Accuracy: 0.6815
Epoch 12, Loss: 632.8125, Accuracy: 0.6635
Epoch 13, Loss: 609.3750, Accuracy: 0.6866
Epoch 14, Loss: 646.8750, Accuracy: 0.6661
Epoch 15, Loss: 576.5625, Accuracy: 0.6918
Epoch 16, Loss: 657.8125, Accuracy: 0.6473
Epoch 17, Loss: 670.3125, Accuracy: 0.6635
Epoch 18, Loss: 678.1250, Accuracy: 0.6336
Epoch 19, Loss: 585.9375, Accuracy: 0.6892
Epoch 20, Loss: 626.5625, Accuracy: 0.6721
Epoch 21, Loss: 623.4375, Accuracy: 0.6712
Epoch 22, Loss: 618.7500, Accuracy: 0.6738
Epoch 23, Loss: 678.1250, Accuracy: 0.6438
Epoch 24, Loss: 604.

KeyboardInterrupt: 

#### Step 4: Testing the Model

##### Let's Run Experiments !!!!!

Now let's obtain our test accuracy by running the following code block:

#### Step 5: Let's visualize our results

In [14]:
# Function to plot loss and accuracy
def plot_metrics(train, val=None, name1="train", name2="val", title=""):
    """Plot one or two per-epoch metric curves in a new figure.

    Args:
        train: Sequence of values, one per epoch, for the first curve.
        val: Optional second sequence drawn on the same axes.
        name1: Legend label of the first curve.
        name2: Legend label of the second curve.
        title: Figure title. Its first word (e.g. "Loss", "Accuracy") is used
            as the y-axis label; an empty title leaves the label empty.
    """
    plt.figure()
    plt.title(title)
    plt.plot(train, label=name1)
    if val is not None:
        plt.plot(val, label=name2)
    plt.legend()
    plt.xlabel("Epoch")
    # Guard the default empty title: "".split() is [], so indexing would fail.
    title_words = title.split()
    plt.ylabel(title_words[0] if title_words else "")  # "Loss" or "Accuracy"
    plt.show()


# Plot loss and accuracy
# These calls belong at module level: nested inside plot_metrics they were
# never run on their own and made the function call itself without end.
plot_metrics(train_loss, title="Loss Graph")
plot_metrics(train_accuracy, title="Accuracy Graph")