In [8]:
import torch
import torch.nn as nn
import torchvision.models as models

def get_pretrained_regression_model(output_size):
    # Load the pretrained ResNet18 model
    pretrained_model = models.resnet18(weights='DEFAULT')
    
    # Modify the last fully connected layer for regression with custom output size
    in_features = pretrained_model.fc.in_features
    pretrained_model.fc = nn.Linear(in_features, output_size)
    
    return pretrained_model

model = get_pretrained_regression_model(100)

In [67]:
from torch.utils.data import Dataset, DataLoader
from sklearn.decomposition import PCA 
from torchvision import transforms
from PIL import Image
import numpy as np

# Define an imaginary dataset class
class ImaginaryDataset(Dataset):
    def __init__(self, num_samples, transform=None, PCA = None):
        self.num_samples = num_samples
        self.transform = transform
        self.PCA = PCA
        self.data, self.output = self.generate_data()

    def generate_data(self):
        data = []
        output_concat = []
        for _ in range(self.num_samples):
            # Generate random image data (3 channels, 128x128 pixels)
            image = np.random.randint(0, 256, size=(3, 128, 128), dtype=np.uint8)
            image = Image.fromarray(np.transpose(image, (1, 2, 0)))

            # Generate random output array of length 1000
            output = np.random.rand(1000)

            data.append((image, output))
            output_concat.append(output)
        if self.PCA: self.PCA.fit(output_concat)
        return data, output_concat
    
    def give_output(self):
        return self.output
    
    def internal_PCA(self):
        if self.PCA:
            return self.PCA
        else:
            print(f'PCA is {self.PCA}')

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image, output = self.data[idx]

        if self.transform:
            image = self.transform(image)
        if self.PCA:
            output = self.PCA.transform(output.reshape(1, -1))
        return image, torch.FloatTensor(output[0])

# Define data transformations (you can customize these as needed)
data_transform = transforms.Compose([
    transforms.ToTensor(),
])


# Create an instance of the ImaginaryDataset
num_samples = 1000  # You can adjust this based on your needs
imaginary_dataset = ImaginaryDataset(num_samples, transform= data_transform)


# Create a DataLoader for batch processing
# batch_size = 32
# data_loader = DataLoader(imaginary_dataset, batch_size=batch_size, shuffle=True)

In [69]:
class ImaginaryDataset(Dataset):
    def __init__(self, images_list, outputs_list, transform=None, PCA=None):
        self.num_samples = len(images_list)
        self.transform = transform
        self.PCA = PCA
        self.data, self.output = self.load_data(images_list, outputs_list)

    def load_data(self, images_list, outputs_list):
        data = []
        output_concat = []

        for i in range(self.num_samples):
            # Load image from the given list
            image = Image.fromarray(images_list[i])

            # Load output array from the given list
            output = outputs_list[i]

            data.append((image, output))
            output_concat.append(output)

        if self.PCA:
            self.PCA.fit(output_concat)

        return data, output_concat

    def give_output(self):
        return self.output

    def internal_PCA(self):
        if self.PCA:
            return self.PCA
        else:
            print(f'PCA is {self.PCA}')

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image, output = self.data[idx]

        if self.transform:
            image = self.transform(image)
        if self.PCA:
            output = self.PCA.transform(output.reshape(1, -1))

        return image, torch.FloatTensor(output[0])

In [68]:
Frozen_PCA = imaginary_dataset.internal_PCA()

PCA is None


In [65]:
len(Frozen_PCA.inverse_transform(imaginary_dataset[0][1]))

1000

In [72]:
from ..utils import yay

ImportError: attempted relative import with no known parent package