In [None]:
from predict import perform_inference
from datasets.rfw_latent import RFW_raw, RFW_latent, create_dataloaders
from train import train, write_model, save_model

import torch
import torchvision.models as models
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F

import sys
sys.path.append('/home/tianqiu/NeuralCompression/lossy-vae')
from lvae import get_model
from lvae.models.qresvae import zoo

In [None]:
DEVICE = 1  # index of the CUDA GPU to use when one is available

if torch.cuda.is_available():
    device = torch.device(f'cuda:{DEVICE}')
else:
    device = torch.device('cpu')

In [None]:
# Training / data-split hyper-parameters.
EPOCHS = 5           # NOTE(review): the training call later in the notebook passes 10 epochs literally, not EPOCHS
LEARNING_RATE = 0.01
RATIO = 0.8          # passed to create_dataloaders with BATCH_SIZE; presumably the train fraction — TODO confirm
BATCH_SIZE = 32

# Seed torch's global RNG so model initialisation (and any torch-based shuffling/splitting
# performed by the data helpers) is repeatable under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Load the pre-trained QRes VAE, used in this notebook only as a frozen latent encoder.
model_name = 'qres17m'
lmb_value = 64  # selects the pre-trained checkpoint variant (presumably the rate-distortion lambda)
# Weights are downloaded automatically. No optimizer in this notebook covers the encoder's
# parameters, so freeze it: eval() turns off whatever train-only layer behaviour the model
# defines, and requires_grad_(False) keeps backward passes from building graphs through it
# or accumulating gradients into it.
nc_model = (
    get_model(model_name, lmb_value, pretrained=True)
    .to(device)
    .eval()
    .requires_grad_(False)
)

In [None]:
# RFW image directory (the "data_64" name suggests 64x64 images — confirm) and label file.
RFW_IMAGES_DIR =  "/media/global_data/fair_neural_compression_data/datasets/RFW/data_64"
# NOTE(review): despite the _DIR suffix this is the path of a CSV file, not a directory.
RFW_LABELS_DIR = "/media/global_data/fair_neural_compression_data/datasets/RFW/clean_metadata/numerical_labels.csv"
# Raw-image dataset; the training and prediction cells turn these batches into latents via get_latent.
image_ds = RFW_raw(RFW_IMAGES_DIR, RFW_LABELS_DIR)
# create_dataloaders(dataset, BATCH_SIZE, RATIO) -> (train loader, test loader);
# RATIO is presumably the train fraction — TODO confirm in datasets/rfw_latent.py.
# Keep the statement order: if splitting/shuffling is random, reordering these calls would
# change which samples land in each split.
image_dl_train, image_dl_test = create_dataloaders(image_ds, BATCH_SIZE, RATIO)
# Latent-space dataset built from nc_model on `device` (internals not visible here); in this
# notebook its loaders are only referenced by a commented-out training call.
latent_ds = RFW_latent(RFW_IMAGES_DIR, RFW_LABELS_DIR, nc_model, device)
latent_dl_train, latent_dl_test = create_dataloaders(latent_ds, BATCH_SIZE, RATIO)

In [None]:
def get_latent(img, nc_model, device):
    """Encode an image batch with the QRes VAE and fuse its 12 latent blocks into one map.

    The per-block latents ``z`` from ``nc_model.forward_get_latents`` are merged coarse to fine:

    * at 4x4: latent 0 (resized) + latents 1-2;
    * at 8x8: running map (resized) + latents 3-6;
    * at 16x16: running map (resized) + latents 7-11.

    Each resize is ``F.interpolate`` with its default 'nearest' mode. Concatenation is along
    the channel dim. A final pixel shuffle with factor 2 maps ``(B, 4C, 16, 16)`` to
    ``(B, C, 32, 32)``.

    Args:
        img: image batch; moved to ``device`` before encoding.
        nc_model: object exposing ``forward_get_latents(img)``, which returns an indexable
            sequence of at least 12 dicts each holding a ``'z'`` tensor.
        device: device the encoder lives on.

    Returns:
        The fused latent tensor. It is computed under ``torch.no_grad()``: the encoder is
        frozen, so no autograd graph is built through it.
    """
    # (target resolution, latent block indices concatenated at that resolution)
    stages = ((4, (1, 2)), (8, (3, 4, 5, 6)), (16, (7, 8, 9, 10, 11)))
    img = img.to(device)
    with torch.no_grad():
        stats_all = nc_model.forward_get_latents(img)
        latents = [stats_all[block]['z'] for block in range(12)]
        output = latents[0]
        for size, blocks in stages:
            upsampled = F.interpolate(output, size)
            output = torch.cat([upsampled] + [latents[b] for b in blocks], dim=1)
        # Functional form of nn.PixelShuffle(2); avoids constructing a module on every call.
        return F.pixel_shuffle(output, 2)

In [None]:
class MultiHeadResNet(nn.Module):
    """ResNet-18 backbone with one linear classification head per attribute.

    The pipeline is:

    1. A 1x1 convolution projects the incoming ``in_channels`` feature map to 3 channels.
    2. The map is nearest-upsampled to ``input_size`` x ``input_size``.
    3. It passes through an ImageNet-pretrained torchvision ResNet-18 whose final fc layer
       is removed.
    4. Every head maps the flattened backbone features to class probabilities.

    Args:
        output_dims: mapping head name -> number of classes; heads keep the dict's order.
        in_channels: channels of the incoming feature map (default 19, as before).
        input_size: spatial size fed to the backbone (default 224, as before).
    """

    def __init__(self, output_dims, in_channels=19, input_size=224):
        super(MultiHeadResNet, self).__init__()
        # padding=0: with a 1x1 kernel, padding=1 only adds a border of bias-only pixels
        # around the map, which would then be upsampled into the backbone input.
        self.dim_reducing_layer = nn.Conv2d(in_channels=in_channels, out_channels=3, kernel_size=1, stride=1, padding=0)
        self.upsampling_layer = nn.Upsample(size=input_size, mode='nearest')
        # NOTE: `pretrained=` is deprecated in newer torchvision in favour of `weights=`;
        # kept here for compatibility with older versions.
        backbone = models.resnet18(pretrained=True)
        num_features = backbone.fc.in_features
        # Drop the final child (the fc layer); its in_features sizes the heads below.
        self.resnet = torch.nn.Sequential(*(list(backbone.children())[:-1]))
        self.heads = nn.ModuleDict()
        for head, num_classes in output_dims.items():
            self.heads[head] = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return a dict head name -> (batch, num_classes) softmax probabilities.

        NOTE: these are probabilities, not logits; ``nn.CrossEntropyLoss`` expects logits.
        """
        x = self.dim_reducing_layer(x)
        x = self.upsampling_layer(x)
        # flatten from dim 1 instead of squeeze(): squeeze() would also drop the batch
        # dimension for a batch of one, breaking softmax(dim=1) below.
        features = torch.flatten(self.resnet(x), 1)
        return {
            head: F.softmax(head_module(features), dim=1)
            for head, head_module in self.heads.items()
        }

In [None]:
# Leftover notebook-level training settings. They are not read by train_numerical_rfw,
# which receives num_epochs / lr as arguments.
num_epochs = 5
lr = 0.01

# Classification heads: attribute name -> number of classes.
# NOTE(review): the training and prediction loops pair head i with label column i, so this
# key order must match the column order of the numerical label CSV — confirm.
output_dims = dict(
    skin_type=6,
    eye_type=2,
    nose_type=2,
    lip_type=2,
    hair_type=4,
    hair_color=5,
)

model = MultiHeadResNet(output_dims).to(device)

In [None]:
def train_numerical_rfw(num_epochs, lr, train_loader, device):
    """Train the notebook-global multi-head ``model`` on QRes latents of image batches.

    Relies on the notebook globals ``model``, ``nc_model`` and ``get_latent``.

    Each batch ``(inputs, targets, races)`` is encoded with ``get_latent``. Head ``i`` (in
    the order of ``model``'s output dict) is supervised by ``targets[:, i]``. The loss is
    the sum of per-head cross-entropies, optimised with plain SGD.

    Args:
        num_epochs: number of passes over ``train_loader``.
        lr: SGD learning rate.
        train_loader: iterable of ``(inputs, targets, races)`` batches; ``len()`` sizes the
            progress bar.
        device: device the targets are moved to (``get_latent`` moves the inputs).

    Returns:
        The trained global ``model``.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    # An evaluation cell may have switched the classifier to eval mode; train in train mode.
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        samples_seen = 0
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
            for inputs, targets, _races in train_loader:
                # The encoder is frozen (the optimizer only covers model.parameters()), so
                # do not build an autograd graph through it.
                with torch.no_grad():
                    latents = get_latent(inputs, nc_model, device)
                targets = targets.to(device)
                optimizer.zero_grad()
                outputs = model(latents)
                # The model returns softmax probabilities, not logits. nn.CrossEntropyLoss
                # on them would apply softmax a second time. NLL on log-probabilities is the
                # equivalent cross-entropy; the clamp guards against log(0).
                loss = sum(
                    F.nll_loss(torch.log(probs.clamp_min(1e-12)), targets[:, i].to(torch.int64))
                    for i, probs in enumerate(outputs.values())
                )
                loss.backward()
                optimizer.step()
                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                samples_seen += batch_size
                # Per-sample mean over the epoch so far (correct for a smaller last batch).
                pbar.set_postfix(loss=running_loss / samples_seen)
                pbar.update(1)
    return model

In [None]:
# Scratch shape check: a 3x3 conv (no padding) from 19 to 3 channels, then nearest
# upsampling to 224x224, applied to a random batch of 32 16x16 maps.
# NOTE(review): MultiHeadResNet uses a 1x1 conv here, not 3x3.
processing_head = nn.Sequential(
    nn.Conv2d(in_channels=19, out_channels=3, kernel_size=3),
    nn.Upsample(size=224, mode='nearest'),
)
random_tensor = torch.randn(32, 19, 16, 16)
out = processing_head(random_tensor)
out.shape

In [None]:
# `model.parameters` without parentheses only displays a bound-method repr. Report the
# trainable parameter count and show the architecture instead.
n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'trainable parameters: {n_trainable_params:,}')
model

In [None]:
# Train on raw-image batches; latents are computed per batch inside the training loop.
# Alternative: pass `latent_dl_train` (latent dataset) instead of `image_dl_train`.
# NOTE(review): 10 epochs is passed literally, whereas the config cell's EPOCHS is 5.
model = train_numerical_rfw(
    num_epochs=10,
    lr=LEARNING_RATE,
    train_loader=image_dl_train,
    device=device,
)

In [None]:
# Persist the trained classifier. save_model (from train.py) presumably appends a timestamp
# to the name — confirm.
model_dir = '../models'
model_stem = 'latent_RFW_numerical_all_labels_resnet18'
save_model(model, model_dir, model_stem)

In [None]:
import pandas as pd

# Cleaned RFW metadata; provides the one-hot skintype_* columns plus Class_ID / File used below.
RFW_CLEAN_LABELS_CSV = '/media/global_data/fair_neural_compression_data/datasets/RFW/clean_metadata/clean.csv'
rfw_labels = pd.read_csv(RFW_CLEAN_LABELS_CSV)
rfw_labels.head()

In [None]:
def select_random_images(df, num_images_per_type=6, random_state=None, skin_type_columns=None):
    """Sample ``num_images_per_type`` rows for each one-hot skin-type column of ``df``.

    Args:
        df: label frame with one-hot skin-type columns (value 1 marks membership) plus
            ``Class_ID`` and ``File`` columns.
        num_images_per_type: rows sampled without replacement per skin type. pandas
            raises ``ValueError`` if a type has fewer rows.
        random_state: forwarded to ``DataFrame.sample``. Pass an int for a reproducible
            selection; the default ``None`` keeps the previous unseeded behaviour.
        skin_type_columns: one-hot columns to iterate, in order. Defaults to
            ``skintype_type1`` ... ``skintype_type6``.

    Returns:
        List of ``[skin_type, class_id, file]`` triples grouped by skin type. ``skin_type``
        is the column-name suffix after the last underscore (e.g. ``'type1'``).
    """
    if skin_type_columns is None:
        skin_type_columns = [f'skintype_type{k}' for k in range(1, 7)]
    selected_images = []
    for skintype_col in skin_type_columns:
        skin_type = skintype_col.split('_')[-1]
        images = df[df[skintype_col] == 1].sample(num_images_per_type, random_state=random_state)
        selected_images.extend(
            [skin_type, class_id, image_path]
            for class_id, image_path in images[['Class_ID', 'File']].values.tolist()
        )
    return selected_images

In [None]:
import matplotlib.pyplot as plt

random_images = select_random_images(rfw_labels)
print(f'random_images: {random_images[:4]}')

# 6 skin types x 6 samples each -> a 6x6 grid filled row-major, one row per skin type.
num_rows = 6
num_cols = 6
fig, axs = plt.subplots(num_rows, num_cols, figsize=(20, 20))

for i, (skin_type, class_id, image_path) in enumerate(random_images):
    row, col = divmod(i, num_cols)
    img = plt.imread(f'{RFW_IMAGES_DIR}/{class_id}/{image_path}')
    ax = axs[row, col]
    ax.imshow(img)
    ax.set_title(f"Skin Type: {skin_type}\nClass ID: {class_id}")
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
import torch

# Reload a fully pickled classifier for evaluation.
# NOTE(review): this checkpoint name lacks the `latent_` prefix, while the `model_name` used
# to label the saved predictions is 'latent_RFW_..._2024-05-08_23-18-58'. Confirm that the
# intended checkpoint is loaded.
# model_path = 'models/RFW_numerical_no_skin_resnet18_2024-05-06_16-07-26.pth'
model_path = '../models/RFW_numerical_all_labels_resnet18_2024-05-07_13-45-45.pth'
# map_location: load onto the configured device even if the checkpoint was saved from another
# GPU, or when no GPU is present.
# weights_only=False: the file holds a whole nn.Module, not just a state dict, which newer
# torch versions refuse to unpickle by default. Unpickling can execute arbitrary code, so
# only load trusted checkpoints.
model = torch.load(model_path, map_location=device, weights_only=False)

In [None]:
# Filename prefix for the saved prediction/label tensors. This rebinds the global that
# earlier held the QRes encoder name.
model_name = (
    'latent_RFW_numerical_all_labels_resnet18'
    '_2024-05-08_23-18-58'
)

In [None]:
from tqdm import tqdm
import numpy as np

def save_race_based_predictions(
        model,
        model_name,
        dataloader,
        device,
        prediction_save_dir,
        save_labels=False
    ):
    """Run ``model`` over ``dataloader`` and save per-race, per-head argmax predictions.

    Each batch ``(inputs, labels, race)`` is encoded with the notebook-global ``get_latent``
    (using the global ``nc_model``) before it is fed to ``model``. Head ``i``, in the order
    of ``model``'s output dict, is paired with ``labels[:, i]``. Samples whose race is not
    one of 'Indian', 'Caucasian', 'Asian', 'African' are ignored.

    For every race and head, this writes
    ``{prediction_save_dir}/{model_name}_{race}_{head}_predictions.pt``. If ``save_labels``
    is set, it also writes the matching ``..._labels.pt``. ``prediction_save_dir`` is
    created if it does not exist.

    Returns:
        ``(all_predictions, all_labels)``: dicts race -> head -> 1-D CPU tensor.
    """
    import os

    races = ('Indian', 'Caucasian', 'Asian', 'African')
    heads = list(model.heads.keys())
    # Collect chunks per (race, head) and concatenate once at the end, instead of
    # re-allocating with torch.cat on every batch. Seeding each list with the same empty
    # tensor the previous accumulator started from keeps the resulting dtype unchanged.
    pred_chunks = {r: {h: [torch.tensor([])] for h in heads} for r in races}
    label_chunks = {r: {h: [torch.tensor([])] for h in heads} for r in races}

    print(f'prediction_save_dir: {prediction_save_dir}')
    dataloader = tqdm(dataloader, desc="Getting Predictions", unit="batch")
    model.eval()
    with torch.no_grad():
        for inputs, labels, race in dataloader:
            race = np.array(race)
            latents = get_latent(inputs, nc_model, device)
            # Labels are only indexed and stored, so keep them on the CPU.
            labels = labels.cpu()
            outputs = model(latents)
            # Row indices of each race within this batch (the same for every head).
            race_indices = {r: np.array((race == r).nonzero()[0]) for r in races}

            for i, (head, probs) in enumerate(outputs.items()):
                head_preds = probs.argmax(dim=1).cpu()
                head_labels = labels[:, i]
                for r, idx in race_indices.items():
                    pred_chunks[r][head].append(head_preds[idx])
                    label_chunks[r][head].append(head_labels[idx])

    all_predictions = {r: {h: torch.cat(pred_chunks[r][h], dim=0) for h in heads} for r in races}
    all_labels = {r: {h: torch.cat(label_chunks[r][h], dim=0) for h in heads} for r in races}

    os.makedirs(prediction_save_dir, exist_ok=True)
    for r in races:
        for h in heads:
            stem = f'{prediction_save_dir}/{model_name}_{r}_{h}'
            torch.save(all_predictions[r][h], f'{stem}_predictions.pt')
            if save_labels:
                torch.save(all_labels[r][h], f'{stem}_labels.pt')

    return all_predictions, all_labels


In [None]:
# Evaluate on the held-out image split and write per-race prediction tensors to disk.
# Labels are not written because save_labels keeps its default of False.
all_predictions, all_labels = save_race_based_predictions(
    model=model,
    model_name=model_name,
    dataloader=image_dl_test,
    device=device,
    prediction_save_dir='results/latent_RFW/predictions',
)