In [1]:
# From a resnet-50 model trained to separate gan/real, I use the embeddings from it to get X number of images from
# real and gan dataset.
# Get mean of real images in this embedding space and get X-nearest images to this mean

In [2]:
## Convert lightning checkpoint into torch checkpoint
import torch

# Path to your PyTorch Lightning checkpoint
lightning_checkpoint_path = '/media/nas2/Aref/share/continual_learning/models/resnet50/checkpoints/epoch=59-step=219420-v_loss=0.1618-v_acc=0.9475.ckpt'

# Load the checkpoint
checkpoint = torch.load(lightning_checkpoint_path)

model_state_dict = checkpoint['state_dict']
# model_state_dict
# PyTorch Lightning prefixes each model parameter key with 'model.'
# You may need to remove this prefix if you're loading the state dict into a plain PyTorch model.
model_state_dict = {key.replace('classifier.', '', 1): value for key, value in model_state_dict.items()}
# model_state_dict

In [3]:
# Standard library imports
import os
from pathlib import Path

# Third-party library imports for numerical operations and image handling
import numpy as np
from PIL import Image

# PyTorch core and neural network imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

# PyTorch utilities for data handling
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset

# torchvision for model architectures, image transformations, and utilities
from torchvision import transforms
from torchvision.transforms import Compose, Resize, ToTensor
from torchvision.io import read_image
from torchvision.models import resnet18, resnet50



model = resnet50(weights=None)
        
# Changing the model architecture just to load the pretrained state_dict
model.fc = nn.Linear(2048, 2)
model.load_state_dict(model_state_dict)

<All keys matched successfully>

In [4]:
# Remove the head
model.fc = nn.Identity()
model.cuda()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [5]:
class CustomDataset(Dataset):
    def __init__(self, txt_file_path=None, transform=None):
        """
        Initialize the dataset with either a list of .txt files containing image paths or numpy arrays of images and labels.
        
        Args:
            txt_file_paths (list of str): List of paths to .txt files, each containing image paths. Each .txt file represents a class.
            transform (callable, optional): Optional transform to be applied on a sample.
            images_np (numpy.ndarray, optional): Numpy array of images (used when not using .txt files for image paths).
            labels_np (numpy.ndarray, optional): Numpy array of labels (used when not using .txt files for image paths).
        """
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Load image paths and labels from .txt files
        with open(txt_file_path, 'r') as f:
            for line in f:
                self.image_paths.append(line.strip())  # Remove newline characters
                self.labels.append(0)
        
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        # Load image and label from the list populated from .txt files
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        image = Image.open(img_path)  # Assuming these are paths to images

        if image.size[0]<256 or image.size[1]<256:
            resize_transform = transforms.Resize(256)
            image = resize_transform(image)
        
        if image.mode=='L':
            image = image.convert('RGB')
        
        if self.transform:
            image = self.transform(image)
            
        return idx, image, label


In [6]:
transform = transforms.Compose([
        transforms.RandomCrop(256),
        transforms.ToTensor(),
        # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

model.eval()   # Just to make sure the model is in eval mode

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [7]:
txt_file_path_real = '/media/nas2/Aref/share/continual_learning/dataset_file_paths/db-real/train.txt'
train_dataset_real = CustomDataset(txt_file_path=txt_file_path_real, transform=transform)


In [9]:
from tqdm.notebook import tqdm
def get_mean(dataloader, n):
    # Initialize a tensor for accumulating the sum of features
    feature_sum = torch.zeros(2048)
    
    # Keep track of the total number of samples processed
    total_samples, i = 0, 0
    
    # Loop through batches in the dataloader
    for idx, (idx__, images, labels) in enumerate(tqdm(dataloader, desc="Processing")):
    # for idx, images, labels in dataloader:  # Adjusted for the common dataloader output
        i+=1
        images = images.to(device='cuda' if torch.cuda.is_available() else 'cpu')
        with torch.no_grad():
            features = model(images)
        
        # Sum up the features of the current batch
        feature_sum += features.sum(dim=0).cpu()
        
        # Update the total number of samples
        total_samples += images.size(0)

        if i%100==0:
            print(f'{total_samples} processed out of {n}')
    
    # Compute the mean of the features
    feature_mean = feature_sum / total_samples
    
    return feature_mean



In [None]:
txt_file_path_real = '/media/nas2/Aref/share/continual_learning/dataset_file_paths/db-real/train.txt'
train_dataset_real = CustomDataset(txt_file_path=txt_file_path_real, transform=transform)
train_dataloader_real = DataLoader(train_dataset_real, batch_size=64, shuffle=False)
mean_feature_real = get_mean(train_dataloader_real, len(train_dataset_real))

In [17]:
txt_file_path_gan = '/media/nas2/Aref/share/continual_learning/dataset_file_paths/db-gan/train.txt'
train_dataset_gan = CustomDataset(txt_file_path=txt_file_path_gan, transform=transform)
train_dataloader_gan = DataLoader(train_dataset_gan, batch_size=32, shuffle=False)
mean_feature_gan = get_mean(train_dataloader_gan, len(train_dataset_gan))

Processing:   0%|          | 0/3657 [00:00<?, ?it/s]

3200 processed out of 117000
6400 processed out of 117000
9600 processed out of 117000
12800 processed out of 117000
16000 processed out of 117000
19200 processed out of 117000
22400 processed out of 117000
25600 processed out of 117000
28800 processed out of 117000
32000 processed out of 117000
35200 processed out of 117000
38400 processed out of 117000
41600 processed out of 117000
44800 processed out of 117000
48000 processed out of 117000
51200 processed out of 117000
54400 processed out of 117000
57600 processed out of 117000
60800 processed out of 117000
64000 processed out of 117000
67200 processed out of 117000
70400 processed out of 117000
73600 processed out of 117000
76800 processed out of 117000
80000 processed out of 117000
83200 processed out of 117000
86400 processed out of 117000
89600 processed out of 117000
92800 processed out of 117000
96000 processed out of 117000
99200 processed out of 117000
102400 processed out of 117000
105600 processed out of 117000
108800 proc

In [18]:
torch.save(mean_feature_real,'mean_feature_real.pt')
torch.save(mean_feature_gan, 'mean_feature_gan.pt')

In [8]:
mean_feature_real = torch.load('mean_feature_real.pt')
mean_feature_gan = torch.load('mean_feature_gan.pt')

In [18]:
def mse_distance(features, mean_feature):
    # Calculate the MSE distance between features and mean_feature
    distances = torch.sum((features - mean_feature) ** 2, dim=1)
    return distances
    
def find_top_k_images(dataloader, mean_feature, k=1000):
    mean_feature = mean_feature.unsqueeze(0)  # Adjust shape for broadcasting
    top_distances = torch.full((k,), float('inf'))  # Initialize with infinite distances
    top_indices = torch.full((k,), -1, dtype=torch.long)  # Initialize with invalid indices

    # Ensure mean_feature is on the same device as model outputs
    mean_feature = mean_feature.to(device='cuda' if torch.cuda.is_available() else 'cpu')

    current_index = 0  # Keep track of the current global index of images
    # for idx, images, _ in dataloader:
    
    for idx, (idx__, images, labels) in enumerate(tqdm(dataloader, desc="Processing")):
        images = images.to(device='cuda' if torch.cuda.is_available() else 'cpu')
        with torch.no_grad():
            features = model(images).cpu()  # Compute features and move to CPU
        
        distances = mse_distance(features.cpu(), mean_feature.cpu())

        # Combine current distances with top distances and find new top k
        all_distances = torch.cat((top_distances, distances), dim=0)
        top_k_values, top_k_indices = torch.topk(all_distances, k, largest=False, sorted=True)

        # Update top distances and indices
        top_distances = top_k_values
        # Map back to the original indices in the dataset
        original_indices = torch.cat((top_indices, torch.arange(current_index, current_index + len(images))), dim=0)
        top_indices = original_indices[top_k_indices]

        current_index += len(images)  # Update global index

    return top_indices[:k]  # Return the indices of the top k closest images


In [19]:
txt_file_path_real = '/media/nas2/Aref/share/continual_learning/dataset_file_paths/db-real/train.txt'
train_dataset_real = CustomDataset(txt_file_path=txt_file_path_real, transform=transform)
train_dataloader_real = DataLoader(train_dataset_real, batch_size=64, shuffle=False)

real_top_index = find_top_k_images(train_dataloader_real, mean_feature_real)

Processing:   0%|          | 0/1828 [00:00<?, ?it/s]

In [21]:
txt_file_path_gan = '/media/nas2/Aref/share/continual_learning/dataset_file_paths/db-gan/train.txt'
train_dataset_gan = CustomDataset(txt_file_path=txt_file_path_gan, transform=transform)
train_dataloader_gan = DataLoader(train_dataset_gan, batch_size=32, shuffle=False)

gan_top_index = find_top_k_images(train_dataloader_gan, mean_feature_gan)

Processing:   0%|          | 0/3657 [00:00<?, ?it/s]

In [22]:
real_top_index

tensor([ 81139,  38170,  24759,  37895,  91976,  45564,  58680,  50316,  57203,
         37166,  58502,  10682,  91203,  30546,   6219, 114959,  94668,     32,
         65389,  96120,  78079,  75922,  49952,  58371, 113795,  92223,   6527,
         14794,  24566,  23169,  82372,  41198, 109168,  45974, 111339,  37151,
        101345,  66612,  50820,  96152, 102124,  94645,  77990,  88134,  20299,
         98189,  34194,  77795,   8805, 114619,  71320,  72256, 103730,  58825,
         41749,  42605,  31889,  77504,  28798,  42437, 113966,  27863, 112562,
         59186, 103900,  71215,  31056,  47063,  44337,  80494,  93424,  88740,
        100128,  20972,  31127,  33919,  96626, 101296,   2343,  16138,  48898,
         79379,  92665,  49691,  92776,  34144,  99689,  36160,  84403, 104455,
         40846,  16828,  22490,  34063, 114308,  90097,  59909, 114643,  24497,
         72770,  32354,  30037,  18890,    197,  61873,  32965, 100796,  36821,
         13977, 113880,  23447,  57208, 

In [31]:
all_image_path = []
txt_file_path_real = '/media/nas2/Aref/share/continual_learning/dataset_file_paths/db-real/train.txt'
# Load image paths and labels from .txt files
with open(txt_file_path_real, 'r') as f:
    for line in f:
        all_image_path.append(line.strip())  #

real_image_path = [all_image_path[i] for i in real_top_index.tolist()]

In [34]:
all_image_path = []
txt_file_path_gan = '/media/nas2/Aref/share/continual_learning/dataset_file_paths/db-gan/train.txt'
# Load image paths and labels from .txt files
with open(txt_file_path_gan, 'r') as f:
    for line in f:
        all_image_path.append(line.strip())  #

gan_image_path = [all_image_path[i] for i in gan_top_index.tolist()]

In [37]:
real_image_path, gan_image_path

(['/media/nas2/Datasets/COCO/dataset/images/train2017/000000301146.jpg',
  '/media/nas2/Datasets/COCO/dataset/images/train2017/000000245898.jpg',
  '/media/nas2/Datasets/COCO/dataset/images/train2017/000000274784.jpg',
  '/media/nas2/Datasets/COCO/dataset/images/train2017/000000424378.jpg',
  '/media/nas2/Datasets/COCO/dataset/images/train2017/000000264676.jpg',
  '/media/nas2/Datasets/COCO/dataset/images/train2017/000000416450.jpg',
  '/media/nas2/Datasets/COCO/dataset/images/train2017/000000400075.jpg',
  '/media/nas2/Datasets/COCO/dataset/images/train2017/000000519831.jpg',
  '/media/nas2/Datasets/COCO/dataset/images/train2017/000000261824.jpg',
  '/media/nas2/Datasets/COCO/dataset/images/train2017/000000043291.jpg',
  '/media/nas2/Datasets/COCO/dataset/images/train2017/000000569533.jpg',
  '/media/nas2/Datasets/COCO/dataset/images/train2017/000000309598.jpg',
  '/media/nas2/Datasets/COCO/dataset/images/train2017/000000178438.jpg',
  '/media/nas2/Datasets/COCO/dataset/images/train20

In [47]:
import csv
    
with open('Continual-Learning-of-Synthetic-images/jupyter-notebooks/data/top_1k_real.txt', 'w') as f:
    for line in real_image_path:
        f.write(f"{line}\n")


In [48]:
import csv

with open('Continual-Learning-of-Synthetic-images/jupyter-notebooks/data/top_1k_gan.txt', 'w') as f:
    for line in gan_image_path:
        f.write(f"{line}\n")

In [45]:
real_image_path

['/media/nas2/Datasets/COCO/dataset/images/train2017/000000301146.jpg',
 '/media/nas2/Datasets/COCO/dataset/images/train2017/000000245898.jpg',
 '/media/nas2/Datasets/COCO/dataset/images/train2017/000000274784.jpg',
 '/media/nas2/Datasets/COCO/dataset/images/train2017/000000424378.jpg',
 '/media/nas2/Datasets/COCO/dataset/images/train2017/000000264676.jpg',
 '/media/nas2/Datasets/COCO/dataset/images/train2017/000000416450.jpg',
 '/media/nas2/Datasets/COCO/dataset/images/train2017/000000400075.jpg',
 '/media/nas2/Datasets/COCO/dataset/images/train2017/000000519831.jpg',
 '/media/nas2/Datasets/COCO/dataset/images/train2017/000000261824.jpg',
 '/media/nas2/Datasets/COCO/dataset/images/train2017/000000043291.jpg',
 '/media/nas2/Datasets/COCO/dataset/images/train2017/000000569533.jpg',
 '/media/nas2/Datasets/COCO/dataset/images/train2017/000000309598.jpg',
 '/media/nas2/Datasets/COCO/dataset/images/train2017/000000178438.jpg',
 '/media/nas2/Datasets/COCO/dataset/images/train2017/00000006798

In [50]:
len(real_top_index), len(set(real_top_index))

(1000, 1000)