In [1]:
import os
import torch
import random
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split, Subset
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import json
from tqdm import tqdm

In [2]:
class VGGFace2Dataset:
    """Build train/test triplet loaders from the ``train`` folder of ``root_dir``.

    The image folder is split 80/20 by sample index. The split is cached as
    JSON in ``split_file`` so that later runs reuse the same partition.

    Args:
        root_dir: Directory containing a ``train`` sub-folder laid out for
            ``torchvision.datasets.ImageFolder``.
        batch_size: Batch size used by both data loaders.
        split_file: Path of the JSON file that caches the split indices.

    Attributes:
        train_dataset / test_dataset: ``TripletVGGFace2`` views of each split.
        train_loader / test_loader: ``DataLoader`` objects over those views.
    """

    def __init__(self, root_dir, batch_size, split_file="data_split.json"):
        self.root_dir = root_dir
        self.split_file = split_file

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        # Random flipping is an augmentation: it belongs to the training
        # pipeline only, otherwise evaluation results change from run to run.
        train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(),
            normalize,
        ])
        eval_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ])

        # Two views of the same folder: ImageFolder enumerates files in a
        # deterministic sorted order, so sample indices agree between them.
        image_dir = os.path.join(self.root_dir, 'train')
        train_source = datasets.ImageFolder(root=image_dir, transform=train_transform)
        eval_source = datasets.ImageFolder(root=image_dir, transform=eval_transform)

        train_indices, test_indices = self._get_split_indices(len(train_source))

        self.train_dataset = TripletVGGFace2(train_source, train_indices)
        self.test_dataset = TripletVGGFace2(eval_source, test_indices)

        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
        self.test_loader = DataLoader(self.test_dataset, batch_size=batch_size, shuffle=False)

    def _get_split_indices(self, dataset_size):
        """Return ``(train_indices, test_indices)`` for a dataset of ``dataset_size``.

        Reuses the split stored in ``self.split_file`` when that file exists;
        otherwise draws a random 80/20 permutation split and writes it there.

        Raises:
            ValueError: If a cached split refers to indices outside
                ``[0, dataset_size)``, i.e. it was written for another dataset.
        """
        if os.path.exists(self.split_file):
            with open(self.split_file, 'r') as f:
                split = json.load(f)
            train_indices, test_indices = split['train'], split['test']

            # A stale cache would otherwise only surface later as a bare
            # IndexError deep inside the dataset wrapper.
            for part in (train_indices, test_indices):
                if part and (min(part) < 0 or max(part) >= dataset_size):
                    raise ValueError(
                        f"Split file '{self.split_file}' contains indices outside "
                        f"[0, {dataset_size}); delete it to regenerate the split."
                    )
            return train_indices, test_indices

        indices = torch.randperm(dataset_size).tolist()
        train_size = int(dataset_size * 0.8)
        train_indices = indices[:train_size]
        test_indices = indices[train_size:]

        with open(self.split_file, 'w') as f:
            json.dump({'train': train_indices, 'test': test_indices}, f)
        return train_indices, test_indices

class TripletVGGFace2(torch.utils.data.Dataset):
    """Serve (anchor, positive, negative) image triplets from a labelled dataset.

    Args:
        dataset: Indexable dataset exposing ``targets`` (one label per sample)
            and returning a tuple whose first element is the image.
        indices: Indices into ``dataset`` that belong to this split.
    """

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices
        self.labels = [dataset.targets[i] for i in indices]
        self.label_to_indices = self._get_label_to_indices()

        # Cached once so that drawing a negative label is O(1) per item
        # instead of rebuilding a list of every label on each access.
        self._label_list = list(self.label_to_indices)
        self._label_position = {label: pos for pos, label in enumerate(self._label_list)}

    def _get_label_to_indices(self):
        """Group the split's dataset indices by their label."""
        label_to_indices = {}
        for idx in self.indices:
            label_to_indices.setdefault(self.dataset.targets[idx], []).append(idx)
        return label_to_indices

    def __getitem__(self, idx):
        """Return ``(anchor_img, positive_img, negative_img)`` for split position ``idx``.

        The positive is another sample of the anchor's label; when the anchor
        is the only sample of its label in this split, the anchor itself is
        re-read as the positive. The negative is a random sample of a
        uniformly chosen different label.

        Raises:
            ValueError: If the split contains fewer than two labels, so no
                negative can exist.
        """
        if len(self._label_list) < 2:
            raise ValueError("Triplet sampling needs at least two distinct labels in the split.")

        anchor_idx = self.indices[idx]
        anchor_label = self.dataset.targets[anchor_idx]
        anchor_img = self.dataset[anchor_idx][0]

        candidates = [i for i in self.label_to_indices[anchor_label] if i != anchor_idx]
        # random.choice raises on an empty list; singleton identities are
        # common after a random split, so fall back to the anchor.
        positive_idx = random.choice(candidates) if candidates else anchor_idx
        positive_img = self.dataset[positive_idx][0]

        # Uniform draw over "all labels except the anchor's": pick a slot in
        # a list that is one shorter and step over the anchor's position.
        slot = random.randrange(len(self._label_list) - 1)
        if slot >= self._label_position[anchor_label]:
            slot += 1
        negative_label = self._label_list[slot]
        negative_idx = random.choice(self.label_to_indices[negative_label])
        negative_img = self.dataset[negative_idx][0]

        return anchor_img, positive_img, negative_img

    def __len__(self):
        """Number of anchors, i.e. the number of indices in this split."""
        return len(self.indices)

In [3]:
class ConvolutionLayer(nn.Module):
    """Single stride-1 2-D convolution followed by a ReLU.

    Args:
        in_channels: Channels of the incoming image tensor.
        out_channels: Number of feature maps produced.
        kernel_size: Side length of the convolution kernel.
    """

    def __init__(self, in_channels=3, out_channels=256, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1)

    def forward(self, x):
        """Apply the convolution and rectify the resulting feature map."""
        feature_map = self.conv(x)
        return F.relu(feature_map)

class PrimaryCaps(nn.Module):
    """Primary capsule layer: parallel stride-2 convolutions regrouped into capsules.

    ``num_capsules`` convolutions (kernel ``kernel_size``, stride 2, padding 1)
    are applied to the same input. Their outputs are rearranged so that each
    spatial position becomes one route carrying a vector of
    ``num_capsules * out_channels`` values, which is then squashed.

    Args:
        num_capsules: Number of parallel convolutions.
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by each convolution.
        kernel_size: Kernel size of each convolution.
        num_routes: Initial value of ``self.num_routes``; it is overwritten
            with the real route count (``h * w``) on every forward pass.
    """

    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=3, num_routes=32 * 6 * 6):
        super(PrimaryCaps, self).__init__()
        self.num_capsules = num_capsules
        self.out_channels = out_channels
        self.num_routes = num_routes
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=2, padding=1)
            for _ in range(num_capsules)])

    def forward(self, x):
        """Return squashed capsules of shape ``(batch, h * w, num_capsules * out_channels)``."""
        batch_size = x.size(0)
        # (batch, num_capsules, out_channels, h, w)
        u = torch.stack([capsule(x) for capsule in self.capsules], dim=1)

        # The spatial size depends on the input, so record the real route count.
        h, w = u.size(3), u.size(4)
        self.num_routes = h * w

        # Merge the capsule and channel axes, then put the routes first.
        u = u.view(batch_size, self.num_capsules * self.out_channels, h * w)
        u = u.permute(0, 2, 1).contiguous()

        return self.squash(u)

    def squash(self, input_tensor, eps=1e-8):
        """Squash vectors along the last axis to a length in ``[0, 1)``.

        Computes ``|s|^2 / (1 + |s|^2) * s / |s|``. ``eps`` is added under the
        square root so that an all-zero vector maps to zero instead of the
        NaN produced by ``0 / 0``.
        """
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        scale = squared_norm / ((1. + squared_norm) * torch.sqrt(squared_norm + eps))
        return scale * input_tensor

class DigitalCaps(nn.Module):
    """Output capsule layer with three iterations of routing-by-agreement.

    Every input route is projected by its own weight matrix into one
    prediction vector per output capsule; routing then combines the
    predictions into ``num_capsules`` output vectors of ``out_channels``
    values each.

    Args:
        num_capsules: Number of output capsules.
        num_routes: Expected number of input routes. If the input carries a
            different number, the weight tensor is re-initialised to fit.
        in_channels: Input width divided by 8; the layer expects inputs whose
            last axis has ``in_channels * 8`` entries.
        out_channels: Length of each output capsule vector.
    """

    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16):
        super(DigitalCaps, self).__init__()

        # NOTE(review): the factor 8 is hard-coded. It matches PrimaryCaps'
        # output width only when PrimaryCaps uses 8 capsules and `in_channels`
        # equals its out_channels (true for Config; not for the bare defaults).
        self.in_channels = in_channels * 8
        self.num_routes = num_routes
        self.num_capsules = num_capsules
        self.out_channels = out_channels

        # One (out_channels x in_channels) matrix per (route, output capsule).
        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, self.in_channels))

    def _resize_routes(self, num_routes, device):
        """Re-initialise ``W`` for ``num_routes`` routes without replacing the Parameter.

        The tensor behind the existing Parameter is swapped rather than
        binding a new ``nn.Parameter``: an optimizer created beforehand keeps
        a reference to the original object and would otherwise never update
        the weights that are actually used.
        """
        self.W.data = torch.randn(
            1, num_routes, self.num_capsules, self.out_channels, self.in_channels,
            device=device, dtype=self.W.dtype,
        )
        # A gradient of the previous shape can no longer be accumulated into.
        self.W.grad = None

    def forward(self, x):
        """Route ``x`` of shape ``(batch, routes, in_channels * 8)``.

        Returns:
            Tensor of shape ``(batch, num_capsules, out_channels, 1)``.
        """
        num_routes = x.size(1)
        self.num_routes = num_routes
        if self.W.size(1) != num_routes:
            self._resize_routes(num_routes, x.device)

        # Prediction vectors u_hat[b, r, c, :] = W[r, c] @ x[b, r]. einsum
        # contracts directly, avoiding `batch` full copies of W.
        u_hat = torch.einsum('rcoi,bri->brco', self.W[0], x).unsqueeze(4)

        # Routing logits, shared across the batch: (1, routes, capsules, 1, 1).
        b_ij = torch.zeros(1, num_routes, self.num_capsules, 1, 1, device=x.device, dtype=u_hat.dtype)

        num_iterations = 3
        for iteration in range(num_iterations):
            # Each input route distributes its output over the output
            # capsules, so the coefficients are normalised along that axis.
            c_ij = F.softmax(b_ij, dim=2)
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            # Axis 3 holds the capsule vector; the trailing axis has size 1.
            v_j = self.squash(s_j, dim=3)
            if iteration < num_iterations - 1:
                # Agreement = dot product of each prediction with the output.
                a_ij = (u_hat * v_j).sum(dim=3, keepdim=True)
                b_ij = b_ij + a_ij.mean(dim=0, keepdim=True)

        return v_j.squeeze(1)

    def squash(self, input_tensor, dim=-1, eps=1e-8):
        """Squash vectors along ``dim`` to a length in ``[0, 1)``.

        ``eps`` keeps the result finite (zero) for all-zero vectors instead
        of producing NaN from ``0 / 0``.
        """
        squared_norm = (input_tensor ** 2).sum(dim, keepdim=True)
        scale = squared_norm / ((1. + squared_norm) * torch.sqrt(squared_norm + eps))
        return scale * input_tensor

class Decoder(nn.Module):
    """Reconstruct an image from the capsule with the largest length.

    Args:
        input_width: Width of the reconstructed image.
        input_height: Height of the reconstructed image.
        input_channel: Channels of the reconstructed image.
        num_capsules: Number of capsules in the incoming tensor.
        capsule_dim: Length of each incoming capsule vector.
    """

    def __init__(self, input_width=28, input_height=28, input_channel=1, num_capsules=10, capsule_dim=16):
        super(Decoder, self).__init__()
        self.input_width = input_width
        self.input_height = input_height
        self.input_channel = input_channel
        self.num_capsules = num_capsules
        self.capsule_dim = capsule_dim
        # The attribute name (including its spelling) is kept: it is part of
        # the state_dict keys of previously saved checkpoints.
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(capsule_dim * num_capsules, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, self.input_height * self.input_width * self.input_channel),
            nn.Sigmoid()
        )

    def forward(self, x, data):
        """Mask all but the longest capsule of each sample and decode it.

        Args:
            x: Capsule tensor ``(batch, capsules, dim)`` or
                ``(batch, capsules, dim, 1)``.
            data: Unused; kept for interface compatibility.

        Returns:
            ``(reconstructions, masked)`` where ``reconstructions`` has shape
            ``(batch, input_channel, input_height, input_width)`` and
            ``masked`` is a one-hot ``(batch, capsules)`` tensor marking the
            selected capsule.
        """
        batch = x.size(0)
        num_classes = x.size(1)
        capsules = x.reshape(batch, num_classes, -1)

        # The winner is chosen per sample from the capsule lengths. The former
        # softmax over the batch axis made it depend on the other samples in
        # the batch; a per-sample softmax would not change the argmax.
        lengths = capsules.norm(dim=2)
        winners = lengths.argmax(dim=1)

        # Built on x's device: a CPU-only identity cannot be indexed with
        # CUDA indices.
        masked = torch.eye(num_classes, device=x.device, dtype=x.dtype).index_select(dim=0, index=winners)

        t = (capsules * masked.unsqueeze(2)).reshape(batch, -1)
        reconstructions = self.reconstraction_layers(t)
        # (C, H, W) layout; identical to the previous output for square images.
        reconstructions = reconstructions.view(-1, self.input_channel, self.input_height, self.input_width)
        return reconstructions, masked

class CapsNet(nn.Module):
    """Capsule network producing a 128-d embedding plus a reconstruction.

    Args:
        config: Optional object carrying the layer hyper-parameters read
            below. When it is falsy every sub-module uses its own defaults.
    """

    def __init__(self, config=None):
        super().__init__()
        if not config:
            # Default hyper-parameters of every sub-module.
            self.conv_layer = ConvolutionLayer()
            self.primary_capsules = PrimaryCaps()
            self.digit_capsules = DigitalCaps()
            self.decoder = Decoder()
            self.embedding_layer = nn.Linear(16 * 10, 128)
            return

        self.conv_layer = ConvolutionLayer(
            config.cnn_in_channels,
            config.cnn_out_channels,
            config.cnn_kernel_size,
        )
        self.primary_capsules = PrimaryCaps(
            config.pc_num_capsules,
            config.pc_in_channels,
            config.pc_out_channels,
            config.pc_kernel_size,
            config.pc_num_routes,
        )
        self.digit_capsules = DigitalCaps(
            config.dc_num_capsules,
            config.dc_num_routes,
            config.dc_in_channels,
            config.dc_out_channels,
        )
        self.decoder = Decoder(config.input_width, config.input_height, config.cnn_in_channels)
        embedding_inputs = config.dc_out_channels * config.dc_num_capsules
        self.embedding_layer = nn.Linear(embedding_inputs, 128)

    def forward(self, data):
        """Return ``(embeddings, reconstructions, masked)`` for a batch of images."""
        feature_maps = self.conv_layer(data)
        primary = self.primary_capsules(feature_maps)
        capsules = self.digit_capsules(primary)

        flattened = capsules.view(capsules.size(0), -1)
        embeddings = self.embedding_layer(flattened)

        reconstructions, masked = self.decoder(capsules, data)
        return embeddings, reconstructions, masked

    def compute_triplet_loss(self, anchor, positive, negative, margin=1.0):
        """Mean hinge triplet loss on Euclidean distances.

        Per sample: ``max(0, d(anchor, positive) - d(anchor, negative) + margin)``.
        """
        positive_gap = F.pairwise_distance(anchor, positive, p=2)
        negative_gap = F.pairwise_distance(anchor, negative, p=2)
        hinge = F.relu(positive_gap - negative_gap + margin)
        return hinge.mean()

In [4]:
class Config:
    """Hyper-parameters for the capsule network.

    Args:
        dataset: Name of the preset to load. Only ``'vgg'`` is defined.

    Raises:
        ValueError: If ``dataset`` names an unknown preset. Previously this
            silently produced an object without any attributes.
    """

    # Stride and padding of the primary-capsule convolutions in this file.
    _PC_STRIDE = 2
    _PC_PADDING = 1

    def __init__(self, dataset='vgg'):
        if dataset != 'vgg':
            raise ValueError(f"Unsupported dataset preset: {dataset!r} (expected 'vgg')")

        # Input image size (also the decoder's reconstruction size)
        self.input_width = 224
        self.input_height = 224

        # First convolution
        self.cnn_in_channels = 3
        self.cnn_out_channels = 256
        self.cnn_kernel_size = 3

        # Primary Capsule
        self.pc_num_capsules = 8
        self.pc_in_channels = 256
        self.pc_out_channels = 32
        self.pc_kernel_size = 3

        # One route per spatial position of the primary-capsule output.
        # Derived from the layer geometry rather than hard-coded: for
        # 224x224 inputs this is 111 * 111, not 32 * 32. A wrong value makes
        # the digit-capsule layer rebuild its weights on the first forward.
        routes_h = self._primary_caps_size(self.input_height)
        routes_w = self._primary_caps_size(self.input_width)
        self.pc_num_routes = routes_h * routes_w

        # Digit Capsule
        self.dc_num_capsules = 10
        self.dc_num_routes = self.pc_num_routes
        self.dc_in_channels = 32
        self.dc_out_channels = 16

    def _primary_caps_size(self, size):
        """Spatial extent after the stride-1, unpadded conv and the primary-capsule conv."""
        after_conv = size - self.cnn_kernel_size + 1
        return (after_conv + 2 * self._PC_PADDING - self.pc_kernel_size) // self._PC_STRIDE + 1

In [5]:
def save_model(model, filename="model_weights.pth"):
    """Write ``model.state_dict()`` to ``modelcheckpoints/<filename>``.

    The ``modelcheckpoints`` directory is created relative to the current
    working directory when it does not exist yet.
    """
    checkpoint_dir = "modelcheckpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)
    target_path = os.path.join(checkpoint_dir, filename)
    torch.save(model.state_dict(), target_path)

def train_triplet(model, optimizer, train_loader, epoch, triplet_loss):
    """Run one training epoch over triplet batches.

    Args:
        model: Module whose forward returns a tuple with the embeddings first.
        optimizer: Optimizer stepping the model's parameters.
        train_loader: Iterable with ``len()`` yielding
            ``(anchor, positive, negative)`` batches.
        epoch: Epoch number, used only for logging.
        triplet_loss: Callable ``(anchor_emb, positive_emb, negative_emb) -> loss``.
            Falls back to ``model.compute_triplet_loss`` when ``None``.

    Returns:
        The mean batch loss of the epoch (``0.0`` for an empty loader).
    """
    # The loss argument used to be ignored in favour of the model's method.
    loss_fn = triplet_loss if triplet_loss is not None else model.compute_triplet_loss

    # Follow the model's own device instead of relying on a notebook global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    total_loss = 0
    for batch_id, (anchor_data, positive_data, negative_data) in enumerate(tqdm(train_loader)):
        anchor_data = anchor_data.to(target_device)
        positive_data = positive_data.to(target_device)
        negative_data = negative_data.to(target_device)

        optimizer.zero_grad()
        anchor_embedding, _, _ = model(anchor_data)
        positive_embedding, _, _ = model(positive_data)
        negative_embedding, _, _ = model(negative_data)
        loss = loss_fn(anchor_embedding, positive_embedding, negative_embedding)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if batch_id % 32 == 0:
            print(f"Epoch [{epoch}], Batch [{batch_id}], Loss: {loss.item():.6f}")

    # Guard the average against an empty loader (previously ZeroDivisionError).
    num_batches = len(train_loader)
    mean_loss = total_loss / num_batches if num_batches else 0.0
    print(f"Epoch [{epoch}], Total Loss: {mean_loss:.6f}")
    return mean_loss

# Main Training Loop
dataset_root = './vggface2_preprocessed'
batch_size = 8
epochs = 30
learning_rate = 0.001
seed = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the generators behind weight initialisation, the dataset split,
# shuffling and triplet sampling so that a Restart & Run All is repeatable.
random.seed(seed)
torch.manual_seed(seed)

# Load Dataset
vgg_dataset = VGGFace2Dataset(root_dir=dataset_root, batch_size=batch_size)
train_loader = vgg_dataset.train_loader

# Initialize CapsNet
config = Config(dataset='vgg')
capsule_net = CapsNet(config)
capsule_net = capsule_net.to(device)

# Dry run before the optimizer exists: the digit-capsule layer re-initialises
# its weight tensor on the first forward pass when the configured route count
# differs from the real one. An optimizer built before that happens can end up
# holding a stale parameter that is never used by the network.
with torch.no_grad():
    capsule_net(torch.zeros(1, config.cnn_in_channels, config.input_height, config.input_width, device=device))

# Define optimizer and training
optimizer = torch.optim.Adam(capsule_net.parameters(), lr=learning_rate)
for epoch in range(1, epochs + 1):
    train_triplet(capsule_net, optimizer, train_loader, epoch, capsule_net.compute_triplet_loss)

# Save Final Model
save_model(capsule_net, filename="final_capsule_net_weights.pth")

  0%|          | 0/20202 [00:00<?, ?it/s]