In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.io import read_image
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import wandb as wandb



In [3]:
def flatten(x):
    """Flatten every dimension except the leading (batch) one.

    Args:
        x: tensor whose first dimension is the batch size N (e.g. N, C, H, W).

    Returns:
        Tensor of shape (N, prod(x.shape[1:])) -- a view when the memory layout
        allows it, otherwise a copy. A 1-D input of shape (N,) becomes (N, 1).
    """
    n = x.shape[0]
    # Compute the feature count explicitly instead of passing -1: when N == 0 the
    # -1 is ambiguous and view/reshape raise. torch.Size.numel() of an empty size is 1.
    features = x.shape[1:].numel()
    # reshape (unlike view) also accepts non-contiguous inputs such as permuted or
    # channels_last tensors.
    return x.reshape(n, features)
class Flatten(nn.Module):
    """Layer form of the module-level ``flatten`` helper, for use inside ``nn.Sequential``.

    Holds no parameters or state. Models saved whole with ``torch.save(model)`` that
    contain this layer need a class with this name defined when they are loaded.
    """

    def forward(self, x):
        # Collapse everything after the batch dimension via the functional helper.
        flat = flatten(x)
        return flat
class Conv_Autoencoder(nn.Module):
    """Two-stage strided-conv autoencoder with a fully connected bottleneck.

    The layer sizes assume 32x32 inputs: two stride-2 convolutions shrink the
    feature map to 8x8 before the linear bottleneck, and two 2x bilinear
    upsamplings bring it back to 32x32 in the decoder. The reconstruction goes
    through Tanh, so it lies in [-1, 1].

    Args:
        input_c: channels of the input image (and of the reconstruction).
        channel_1: channels after the first convolution.
        channel_2: channels after the second convolution.
        hidden_dim: size of the bottleneck vector.
    """

    # Side length of the encoder's final feature map (32 -> 16 -> 8). A class
    # attribute (not set in __init__) so that instances unpickled from older
    # checkpoints, which never ran this __init__, still see it.
    _LATENT_HW = 8

    def __init__(self, input_c, channel_1, channel_2, hidden_dim):
        super().__init__()
        self.channel_2 = channel_2
        latent_features = channel_2 * self._LATENT_HW * self._LATENT_HW
        self.encoder = nn.Sequential(
            # Consume `input_c` channels (previously hard-coded to 3) so the encoder
            # agrees with the decoder, which emits `input_c` channels.
            nn.Conv2d(input_c, channel_1, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(channel_1, channel_2, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            # Parameter-free, so state_dict keys (encoder.0, encoder.2, encoder.5)
            # are the same as with the hand-written Flatten layer.
            nn.Flatten(),
            nn.Linear(latent_features, hidden_dim),
            nn.ReLU(),
        )
        self.linear = nn.Sequential(nn.Linear(hidden_dim, latent_features))
        self.decoder = nn.Sequential(
            # align_corners=False is the default; spelled out to make it explicit.
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.ConvTranspose2d(channel_2, channel_1, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
            nn.ConvTranspose2d(channel_1, input_c, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def decode(self, hidden_rep):
        """Decode a (N, hidden_dim) bottleneck into an (N, input_c, 32, 32) image in [-1, 1]."""
        side = self._LATENT_HW
        feature_map = self.linear(hidden_rep).reshape(hidden_rep.shape[0], self.channel_2, side, side)
        return self.decoder(feature_map)

    def forward(self, x, return_reconstruction=False):
        """Encode ``x`` and, optionally, decode it back to image space.

        Args:
            x: input batch; the layer sizes assume (N, input_c, 32, 32).
            return_reconstruction: False (default, the previous behaviour) returns
                the bottleneck vector; True returns the decoded image instead.

        Returns:
            (N, hidden_dim) bottleneck tensor, or the (N, input_c, 32, 32)
            reconstruction when ``return_reconstruction`` is True.

        Side effect: the bottleneck is also stored on ``self.hidden_rep``.
        """
        hidden_rep = self.encoder(x)
        self.hidden_rep = hidden_rep
        if not return_reconstruction:
            # The decoder used to run here and its output was discarded; skip it.
            return hidden_rep
        return self.decode(hidden_rep)
    
# Load the whole pickled autoencoder (architecture + weights); unpickling needs the
# model classes (Conv_Autoencoder, Flatten) to be defined in this session.
# NOTE: torch.load with weights_only=False unpickles arbitrary objects -- only load
# checkpoints you trust. The flag is required for full-Module checkpoints on
# PyTorch >= 2.6, where the default became weights_only=True.
autoencoder = torch.load('./final_encoder_model.pt', weights_only=False)
# Freeze every parameter. Assigning `autoencoder.requires_grad = False` only creates a
# plain Python attribute on the Module and leaves the parameters trainable.
autoencoder.requires_grad_(False);

In [75]:
# Evaluation settings.
num_workers = 2
batch_size = 128  # training batch size; evaluation below uses eval_batch_size
eval_batch_size = 10000
data_dir = "/home/jupyter"
test_dir = data_dir + '/test/'
augment_seed = 0  # makes the random augmentations reproducible across runs

# Deterministic preprocessing applied to every image fed to the model:
# uint8 [0, 255] -> float32 [0, 1] -> [-1, 1] (the range of the decoder's Tanh).
# v2.ToDtype(..., scale=True) replaces the deprecated ToTensor / ConvertImageDtype.
to_model_input = v2.Compose([
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize((0.5,), (0.5,)),
])

# Randomly augmented view of the test set (augmentations applied with probability 0.8).
# Not named `transforms`, so the imported torchvision.transforms module is not shadowed.
augment_transforms = v2.Compose([
    v2.ToImage(),  # PIL image -> uint8 image tensor
    v2.RandomApply(transforms=[
        v2.RandomResizedCrop(size=(32, 32), scale=(0.9, 0.9), antialias=True),
        v2.GaussianBlur(kernel_size=(5, 5), sigma=1),
        v2.ColorJitter(brightness=0.5),
    ], p=0.8),
    to_model_input,
])
# Un-augmented view: identical preprocessing without the random ops, so both views are
# on the same [-1, 1] scale (previously it stayed uint8 in [0, 255]).
plain_transforms = v2.Compose([v2.ToImage(), to_model_input])

test_dataset_augmented = datasets.ImageFolder(root=test_dir, transform=augment_transforms)
test_dataset = datasets.ImageFolder(root=test_dir, transform=plain_transforms)

# shuffle defaults to False, so both loaders yield images in the same order and the
# i-th augmented representation pairs with the i-th original one.
test_loader = torch.utils.data.DataLoader(test_dataset_augmented,
                                          batch_size=eval_batch_size,
                                          num_workers=num_workers,
                                          generator=torch.Generator().manual_seed(augment_seed))
test_loader_notransform = torch.utils.data.DataLoader(test_dataset,
                                                      batch_size=eval_batch_size,
                                                      num_workers=num_workers)

# Encode both views. Inputs go to whatever device holds the loaded weights, and
# per-batch outputs are concatenated so nothing is lost if the set spans >1 batch.
device = next(autoencoder.parameters()).device
autoencoder.eval()
with torch.no_grad():
    augmented_batches = []
    for augmented, _ in test_loader:
        augmented = augmented.to(device=device, dtype=torch.float32)
        # Sanity check (replaces the old min/max prints): inputs must be in [-1, 1].
        assert augmented.min() >= -1 - 1e-6 and augmented.max() <= 1 + 1e-6, "unexpected input range"
        augmented_batches.append(autoencoder(augmented))
    reconstructed_augmented = torch.cat(augmented_batches)

    original_batches = []
    for original, _ in test_loader_notransform:
        original = original.to(device=device, dtype=torch.float32)
        original_batches.append(autoencoder(original))
    reconstructed_original = torch.cat(original_batches)