In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10
!pip install efficientnet_pytorch
from efficientnet_pytorch import EfficientNet
import numpy as np
import cv2

Collecting efficientnet_pytorch
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: efficientnet_pytorch
  Building wheel for efficientnet_pytorch (setup.py) ... [?25l[?25hdone
  Created wheel for efficientnet_pytorch: filename=efficientnet_pytorch-0.7.1-py3-none-any.whl size=16428 sha256=0aeb522e37d46d7c57b4fc93c008003be70aef5ed75d817d9fda3613c8514069
  Stored in directory: /root/.cache/pip/wheels/03/3f/e9/911b1bc46869644912bda90a56bcf7b960f20b5187feea3baf
Successfully built efficientnet_pytorch
Installing collected packages: efficientnet_pytorch
Successfully installed efficientnet_pytorch-0.7.1


To build an image jigsaw solver, I first wrote functions that divide an image into 9 patches (a 3×3 grid), resize each patch, and shuffle the patches according to a permutation, with the number of permutations set to 64. I then used the CIFAR-10 dataset as the image source and, using these functions, split each image into ordered and shuffled patch sets.

In [2]:
def shuffle_patches(image, permutation, grid=3):
    """Rearrange the tiles of a ``grid`` x ``grid`` partition of ``image``.

    The image is cut into ``grid * grid`` equally sized tiles in row-major
    order (tile 0 is top-left, tile ``grid * grid - 1`` is bottom-right), and
    output position ``k`` receives tile ``permutation[k]``.

    Args:
        image: array whose first two axes are height and width; any trailing
            axes (e.g. channels) are carried along. Height and width must be
            divisible by ``grid``.
        permutation: sequence of ``grid * grid`` tile indices.
        grid: number of tiles per side (default 3, i.e. 3x3 patches).

    Returns:
        A new array with the same shape as ``image``.

    Raises:
        ValueError: if ``permutation`` does not have ``grid * grid`` entries,
            or if the image cannot be split evenly into the grid.
    """
    num_tiles = grid * grid
    if len(permutation) != num_tiles:
        raise ValueError(
            f"permutation must have {num_tiles} entries, got {len(permutation)}"
        )
    # Splitting along axis 0 alone would give horizontal strips, not a grid:
    # split into row bands first, then split every band into columns.
    tiles = [
        tile
        for band in np.split(image, grid, axis=0)
        for tile in np.split(band, grid, axis=1)
    ]
    shuffled = [tiles[i] for i in permutation]
    # Reassemble row by row so the output keeps the input's shape.
    rows = [
        np.concatenate(shuffled[r * grid:(r + 1) * grid], axis=1)
        for r in range(grid)
    ]
    return np.concatenate(rows, axis=0)

def generate_training_pairs(images, patch_size=224, grid=3, rng=None):
    """Build a ``(shuffled, ordered)`` pair of patch stacks for every image.

    Each image is cut into an exact ``grid`` x ``grid`` grid of tiles in
    row-major order. Rows/columns left over when the height or width is not
    divisible by ``grid`` are dropped, so exactly ``grid * grid`` patches are
    always produced. Every tile is resized with ``cv2.resize`` to
    ``patch_size`` x ``patch_size``.

    Args:
        images: iterable of arrays laid out as (height, width[, channels]).
        patch_size: side length in pixels of each resized patch.
        grid: tiles per side (default 3 -> 9 patches).
        rng: optional ``numpy.random.Generator`` used for shuffling; when
            None the global ``np.random`` state is used.

    Returns:
        A list with one ``(shuffled, ordered)`` tuple per image. ``ordered`` is
        an array of the ``grid * grid`` patches stacked on axis 0 and
        ``shuffled`` is a random reordering of it along axis 0.
        NOTE(review): the permutation that was applied is not returned, so a
        caller has no permutation label to train a classifier against.

    Raises:
        ValueError: if an image is smaller than ``grid`` pixels in height or
            width.
    """
    shuffle = np.random.permutation if rng is None else rng.permutation
    pairs = []
    for image in images:
        tile_h = image.shape[0] // grid
        tile_w = image.shape[1] // grid
        if tile_h == 0 or tile_w == 0:
            raise ValueError(
                f"image of shape {image.shape} is too small for a {grid}x{grid} grid"
            )
        # Iterate exactly `grid` tiles per side. Stepping over the full extent
        # by `shape // grid` adds an extra sliver tile whenever the size is not
        # divisible by `grid` (e.g. 224 px -> 4x4 = 16 patches instead of 9).
        patches = np.array([
            cv2.resize(
                image[r * tile_h:(r + 1) * tile_h, c * tile_w:(c + 1) * tile_w],
                (patch_size, patch_size),
            )
            for r in range(grid)
            for c in range(grid)
        ])
        # Training pair: (shuffled patches, patches in original order).
        pairs.append((shuffle(patches), patches))
    return pairs

The `ImageJigsawDataset` class takes the CIFAR-10 dataset and converts each image into a pair of shuffled and ordered patch sets for the network to train on and identify.

In [3]:
# Custom Dataset class for the image jigsaw puzzle task with transformations
class ImageJigsawDataset(Dataset):
    """Wrap an ``(image, label)`` dataset so each item becomes jigsaw pairs.

    The label of the wrapped sample is discarded. The image is optionally
    transformed, converted to a channel-last NumPy array, and passed to
    ``generate_training_pairs``.
    """

    def __init__(self, cifar_dataset, transform=None):
        # Any indexable dataset yielding (image, label) tuples works here.
        self.cifar_dataset = cifar_dataset
        self.transform = transform

    def __len__(self):
        return len(self.cifar_dataset)

    @staticmethod
    def _to_hwc(img):
        """Return ``img`` as a NumPy array in (height, width[, channels]) layout.

        3-D tensors are assumed to be channel-first (C, H, W), as produced by
        ``transforms.ToTensor``, and are moved to channel-last because
        ``generate_training_pairs`` and ``cv2.resize`` treat axis 0 as height.
        Non-tensor inputs (PIL images, arrays) go through ``np.array`` as-is.
        """
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu()
            if img.dim() == 3:
                img = img.permute(1, 2, 0).contiguous()
            # Copy so the returned array never aliases the tensor's storage.
            return img.numpy().copy()
        return np.array(img)

    def __getitem__(self, idx):
        """Return a one-element list holding the ``(shuffled, ordered)`` pair for item ``idx``."""
        img, _ = self.cifar_dataset[idx]
        if self.transform:
            img = self.transform(img)
        pairs = generate_training_pairs([self._to_hwc(img)])

        return pairs


Next, I created a context-free network (CFN) that uses the EfficientNet model to train on the ordered and shuffled pairs. On top of it, I added convolutional layers for more efficient feature extraction and to reduce the dimensionality, so that the fully connected layer stays a manageable size.

In [4]:
# Context-Free Network (CFN) Model
class CFN(nn.Module):
    """Pairwise jigsaw classifier on top of an EfficientNet-style backbone.

    Both inputs go through the shared backbone and a small convolutional head.
    A linear layer maps the absolute difference of the two flattened
    embeddings to ``num_permutations`` logits.

    Args:
        num_permutations: number of output logits.
        efficientnet_model: backbone module. If it exposes an
            ``extract_features`` method (as ``efficientnet_pytorch``'s
            ``EfficientNet`` does), that is used to obtain a spatial feature
            map. Otherwise its ``forward`` is called and must return a 4-D
            (batch, channels, height, width) tensor.
        in_channels: channel count of the backbone feature map. The default
            1280 is the final feature width of EfficientNet-B0; set it for
            other backbones.
    """

    def __init__(self, num_permutations, efficientnet_model, in_channels=1280):
        super(CFN, self).__init__()
        self.efficientnet_model = efficientnet_model
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            # Pin the spatial size to 3x3 so the flattened head always has
            # 256 * 9 features, matching `fc`, whatever the input resolution.
            nn.AdaptiveAvgPool2d((3, 3)),
        )
        self.fc = nn.Linear(256 * 9, num_permutations)  # 256 channels x 3x3 cells

    def _embed(self, x):
        """Backbone feature map -> conv head -> (batch, 256 * 9) embedding."""
        # A classifier forward returns 2-D logits, which Conv2d cannot consume,
        # so prefer the backbone's spatial feature extractor when it exists.
        extract = getattr(self.efficientnet_model, "extract_features", None)
        features = extract(x) if callable(extract) else self.efficientnet_model(x)
        return torch.flatten(self.conv(features), start_dim=1)

    def forward(self, input1, input2):
        """Return (batch, num_permutations) logits for two image batches."""
        output = torch.abs(self._embed(input1) - self._embed(input2))
        return self.fc(output)

Here I load the EfficientNet model and the CIFAR-10 dataset, and wrap the loaded data in `ImageJigsawDataset` so that it can be used to train the model.

In [5]:
# Load the pretrained EfficientNet-B0 backbone.
# NOTE(review): num_classes=10 gives the backbone a 10-way classifier head;
# check whether the jigsaw model actually uses the classifier output.
efficientnet_model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=10)

# Normalise with the ImageNet channel statistics the pretrained weights expect
# (as in the efficientnet_pytorch examples), not the placeholder 0.5 / 0.5.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # upsample every image to 224x224 before tiling
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Create CIFAR-10 dataset for image jigsaw puzzle. CIFAR10 itself gets no
# transform; ImageJigsawDataset applies `transform` per item.
cifar_dataset = CIFAR10(root="./data", train=True, download=True, transform=None)
jigsaw_dataset = ImageJigsawDataset(cifar_dataset, transform=transform)

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b0-355c32eb.pth
100%|██████████| 20.4M/20.4M [00:00<00:00, 48.1MB/s]


Loaded pretrained weights for efficientnet-b0
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 57233903.58it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


Next, I split the data into train (5%), pretrain (45%), and test (the remaining 50%) subsets, as specified.

In [6]:
# Split fractions; the test subset takes the remainder so the three sizes
# always sum to len(jigsaw_dataset).
TRAIN_FRACTION = 0.05
PRETRAIN_FRACTION = 0.45
SPLIT_SEED = 42
BATCH_SIZE = 32

dataset_size = len(jigsaw_dataset)
train_size = int(TRAIN_FRACTION * dataset_size)
pretrain_size = int(PRETRAIN_FRACTION * dataset_size)
test_size = dataset_size - train_size - pretrain_size
# A seeded generator makes the split identical across kernel restarts.
train_dataset, pretrain_dataset, test_dataset = random_split(
    jigsaw_dataset,
    [train_size, pretrain_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
pretrain_loader = DataLoader(pretrain_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

I then defined the optimizer and loss criterion, and wrote functions to pretrain and fine-tune the model.

In [7]:
# Loss first (it has no dependencies), then the 64-way jigsaw model and its Adam optimizer.
criterion = nn.CrossEntropyLoss()
model = CFN(efficientnet_model=efficientnet_model, num_permutations=64)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
def _train_epochs(model, loader, criterion, optimizer, epochs, phase):
    """Shared optimisation loop used by both pretraining and fine-tuning.

    Runs ``epochs`` passes over ``loader``. After each pass it prints the mean
    batch loss, prefixed with ``phase``.
    """
    num_batches = len(loader)
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for batch in loader:
            # NOTE(review): ImageJigsawDataset items are one-element lists, so
            # with the default collate batch[0] is the collated
            # (shuffled, ordered) pair. The comprehensions below then iterate
            # over those two tensors rather than over samples; confirm the
            # intended batch layout.
            pairs = batch[0]

            input1 = torch.stack([pair[0] for pair in pairs])
            input2 = torch.stack([pair[1] for pair in pairs])
            # NOTE(review): targets are just 0..len(pairs)-1, not the index of
            # the permutation that produced the shuffled input; the dataset
            # does not currently expose such a label.
            target = torch.arange(len(pairs))

            optimizer.zero_grad()
            outputs = model(input1, input2)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # An empty loader reports NaN instead of raising ZeroDivisionError.
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"{phase} Epoch {epoch + 1}/{epochs}, Loss: {mean_loss}")


def pretrain_cfn_model(model, pretrain_loader, criterion, optimizer, epochs=5):
    """Train ``model`` on ``pretrain_loader`` for ``epochs`` epochs.

    Prints ``Pretrain Epoch i/N, Loss: <mean batch loss>`` after each epoch.
    """
    _train_epochs(model, pretrain_loader, criterion, optimizer, epochs, "Pretrain")


def finetune_cfn_model(model, train_loader, criterion, optimizer, epochs=5):
    """Train ``model`` on ``train_loader`` for ``epochs`` epochs.

    Prints ``Fine-tune Epoch i/N, Loss: <mean batch loss>`` after each epoch.
    """
    _train_epochs(model, train_loader, criterion, optimizer, epochs, "Fine-tune")

In [None]:
# Pretrain on the 45% split first, then fine-tune on the 5% split.
EPOCHS = 5
for phase_fn, phase_loader in ((pretrain_cfn_model, pretrain_loader),
                               (finetune_cfn_model, train_loader)):
    phase_fn(model, phase_loader, criterion, optimizer, epochs=EPOCHS)

def evaluate_model(model, test_loader):
    """Print and return the top-1 accuracy of ``model`` on ``test_loader``.

    The model runs in eval mode under ``torch.no_grad()``. Each batch is
    unpacked as ``batch[0]`` -> iterable of ``(input1, input2)`` pairs.

    Returns:
        Accuracy as a float in [0, 1], or ``float('nan')`` when the loader
        yields no samples.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in test_loader:
            pairs = batch[0]

            input1 = torch.stack([pair[0] for pair in pairs])
            input2 = torch.stack([pair[1] for pair in pairs])
            # NOTE(review): placeholder targets 0..len(pairs)-1, not
            # permutation labels, so this accuracy is not a jigsaw metric yet.
            target = torch.arange(len(pairs))

            predicted = model(input1, input2).argmax(dim=1)
            total += target.numel()
            correct += (predicted == target).sum().item()

    # Avoid ZeroDivisionError when the test split is empty.
    if total == 0:
        print("Test Accuracy: n/a (no samples)")
        return float("nan")

    accuracy = correct / total
    print(f"Test Accuracy: {accuracy * 100:.2f}%")
    return accuracy


# Report top-1 accuracy on the held-out test split.
evaluate_model(model, test_loader)