In [12]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from PIL import Image
import os
import copy
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm import tqdm
from utils.loss_function import SaliencyLoss
from utils.data_process_uni import TrainDataset,ValDataset
from net.models.SUM import SUM
from net.configs.config_setting import setting_config
import sys
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

In [13]:
# Load the hand metadata table and derive the columns used to partition it.
df = pd.read_csv('HandInfo.csv')

# Bucket `age` into six half-open intervals: [0,21), [21,22), [22,23), [23,24), [24,31), [31,76).
age_bins = [0, 21, 22, 23, 24, 31, 76]
labels = np.arange(6)
df['age_category'] = pd.cut(df['age'], bins=age_bins, labels=labels, right=False, include_lowest=True)

# Keep only rows with accessories == 0. The explicit .copy() makes the filtered
# frame independent of the original, so the column assignments below do not
# write into a slice (which raises pandas' SettingWithCopyWarning).
df = df[df.accessories == 0].copy()

# Binary flags derived from `aspectOfHand`:
#   p = 1 when the value starts with 'p'      (presumably "palmar" -- confirm against the CSV)
#   r = 1 when the value ends with 'right'
# na=False maps missing values to 0, matching the original `== True` comparison.
df['p'] = df.aspectOfHand.str.startswith('p', na=False).astype(int)
df['r'] = df.aspectOfHand.str.endswith('right', na=False).astype(int)

# One frame per (p, r) combination; each is split into train/val separately later.
df_p_r = df[(df.p == 1) & (df.r == 1)]
df_p_l = df[(df.p == 1) & (df.r == 0)]
df_d_r = df[(df.p == 0) & (df.r == 1)]
df_d_l = df[(df.p == 0) & (df.r == 0)]

In [14]:
# ---- Data ----
image_directory = 'Hands'
img_height, img_width = 224, 224
split_size = 0.2  # default `test_size` used by split_data

# ---- Optimisation ----
batch_size = 16
lr = 1e-5
num_epochs = 80

# ---- Best-loss / early-stopping bookkeeping ----
best_loss = float('inf')
early_stop_counter = 0
early_stop_threshold = 8

# ---- Runtime and model configuration ----
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
config = setting_config
model_cfg = config.model_config



In [15]:

# Image pipelines keyed by split name. Both end with the same
# resize -> tensor -> normalize steps; only 'train' prepends random augmentation.
def _resize_tensor_normalize():
    """Return fresh instances of the steps shared by the train and val pipelines."""
    return [
        transforms.Resize((img_height, img_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]


_color_jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)

data_transforms = dict(
    train=transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            # The colour jitter is applied to half of the samples on average.
            transforms.RandomApply([_color_jitter], p=0.5),
        ]
        + _resize_tensor_normalize()
    ),
    val=transforms.Compose(_resize_tensor_normalize()),
)


In [16]:
# Split data into training and validation sets
def split_data(df, test_size=split_size):
    """Split ``df`` into a (train, validation) pair, stratified on the ``id`` column.

    Args:
        df: DataFrame containing an ``id`` column.
        test_size: Share of rows assigned to the validation part.

    Returns:
        Tuple ``(train_df, val_df)``. The split uses a fixed ``random_state`` of 42.
    """
    parts = train_test_split(
        df,
        stratify=df['id'],
        test_size=test_size,
        random_state=42,
    )
    return parts[0], parts[1]

def encode(col, encoder, val_df, train_df):
    """Fit ``encoder`` on the training column and encode that column in both frames.

    The encoder is fitted on ``train_df[col]`` only. The ``col`` column of
    ``train_df`` and then of ``val_df`` is overwritten in place with the
    encoder's output. Nothing is returned.
    """
    encoder.fit(train_df[col])
    for frame in (train_df, val_df):
        frame[col] = encoder.transform(frame[col])

# Split each hand-aspect subset separately, then merge the pieces so every
# aspect is represented in both the training and validation frames.
train_df_p_r, val_df_p_r = split_data(df_p_r)
train_df_p_l, val_df_p_l = split_data(df_p_l)
train_df_d_r, val_df_d_r = split_data(df_d_r)
train_df_d_l, val_df_d_l = split_data(df_d_l)

train_df = pd.concat([train_df_p_r, train_df_d_r, train_df_p_l, train_df_d_l])
val_df = pd.concat([val_df_p_r, val_df_d_r, val_df_p_l, val_df_d_l])
data_len = train_df.shape[0]

# Integer-encode the three targets; each encoder is fitted on the training frame only.
pairs = [('id', LabelEncoder()), ('age_category', LabelEncoder()), ('gender', LabelEncoder())]
for pair in pairs:
    encode(pair[0], pair[1], val_df, train_df)

num_classes1 = len(pairs[0][1].classes_)
num_classes2 = len(pairs[1][1].classes_)
num_classes3 = len(pairs[2][1].classes_)


def _one_hot(codes, n_classes):
    """One-hot encode integer ``codes`` with exactly one column per class 0..n_classes-1.

    ``pd.get_dummies`` only emits columns for values that occur in ``codes``.
    If a class is missing from one split, the column positions in that split
    shift, and ``argmax`` over a row no longer returns the encoded class index.
    Reindexing to the full class range keeps train and validation tables aligned.
    """
    return pd.get_dummies(codes).reindex(columns=range(n_classes), fill_value=0)


train_id_one_hot = _one_hot(train_df['id'], num_classes1)
val_id_one_hot = _one_hot(val_df['id'], num_classes1)
train_age_one_hot = _one_hot(train_df['age_category'], num_classes2)
val_age_one_hot = _one_hot(val_df['age_category'], num_classes2)
train_gender_one_hot = _one_hot(train_df['gender'], num_classes3)
val_gender_one_hot = _one_hot(val_df['gender'], num_classes3)


In [17]:
class HandDataset(Dataset):
    """Dataset yielding an RGB image and three label vectors per row.

    Args:
        dataframe: Table with an ``imageName`` column; rows are addressed by position.
        labels: Sequence of three tables (one per target) whose rows line up
            positionally with ``dataframe``.
        image_dir: Directory that ``imageName`` values are joined onto.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, dataframe, labels, image_dir, transform=None):
        self.dataframe = dataframe
        self.labels1, self.labels2, self.labels3 = labels
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        """Number of rows in the backing dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image, [label1, label2, label3])`` for the row at position ``idx``."""
        file_name = self.dataframe.iloc[idx]['imageName']
        image = Image.open(os.path.join(self.image_dir, file_name)).convert('RGB')

        # Each label row is cast to float before being wrapped in a tensor.
        targets = [
            torch.tensor(table.iloc[idx].values.astype(float))
            for table in (self.labels1, self.labels2, self.labels3)
        ]

        if self.transform:
            image = self.transform(image)

        return image, targets

# Wrap each split in a HandDataset, pairing it with its three one-hot label tables.
train_dataset = HandDataset(
    dataframe=train_df,
    labels=[train_id_one_hot, train_age_one_hot, train_gender_one_hot],
    image_dir=image_directory,
    transform=data_transforms['train'],
)
val_dataset = HandDataset(
    dataframe=val_df,
    labels=[val_id_one_hot, val_age_one_hot, val_gender_one_hot],
    image_dir=image_directory,
    transform=data_transforms['val'],
)

# Both loaders share batch size and worker count; only the training loader shuffles.
_loader_kwargs = dict(batch_size=batch_size, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)


In [18]:
# Release cached, currently unused GPU memory held by PyTorch's allocator
# before the model is built in the following cells.
torch.cuda.empty_cache()

In [19]:
##---------- Prompt Gen Module -----------------------
class PromptGenBlock(nn.Module):
    """Generate an input-conditioned prompt feature map.

    A bank of ``prompt_len`` learnable prompt tensors is mixed with softmax
    weights predicted from the spatially averaged input. The mixture is then
    resized to the input's spatial size and passed through a 3x3 convolution.

    Args:
        prompt_dim: Channel count of each prompt and of the output.
        prompt_len: Number of prompt tensors in the bank.
        prompt_size: Spatial side length of each stored prompt.
        lin_dim: Channel count of the input ``x`` (input size of the linear layer).
    """

    def __init__(self, prompt_dim=128, prompt_len=5, prompt_size=96, lin_dim=192):
        super(PromptGenBlock, self).__init__()
        # Shape: (1, prompt_len, prompt_dim, prompt_size, prompt_size).
        self.prompt_param = nn.Parameter(torch.rand(1, prompt_len, prompt_dim, prompt_size, prompt_size))
        self.linear_layer = nn.Linear(lin_dim, prompt_len)
        self.conv3x3 = nn.Conv2d(prompt_dim, prompt_dim, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x):
        """Return a prompt map of shape ``(B, prompt_dim, H, W)`` for ``x`` of shape ``(B, lin_dim, H, W)``."""
        _, _, height, width = x.shape

        # Global average over the spatial axes -> (B, lin_dim), then one weight per prompt.
        emb = x.mean(dim=(-2, -1))
        prompt_weights = F.softmax(self.linear_layer(emb), dim=1)

        # (B, L, 1, 1, 1) * (1, L, D, S, S) broadcasts to (B, L, D, S, S). This
        # equals repeating the prompt bank once per sample, without allocating
        # the B explicit copies.
        weighted = prompt_weights[:, :, None, None, None] * self.prompt_param
        prompt = weighted.sum(dim=1)

        prompt = F.interpolate(prompt, (height, width), mode="bilinear")
        return self.conv3x3(prompt)

In [20]:
class FullyConnectedNetwork(nn.Module):
    """Backbone followed by three prompt-conditioned classification heads.

    The backbone output is fused with a head-specific prompt map, flattened,
    and passed through that head's two hidden layers and output layer. The
    1x1 fusion convolution and the dropout layer are shared by all heads.

    Args:
        sum_model: Backbone module. Per the original shape notes it returns a
            channels-last map of shape (B, 7, 7, 768) -- confirm for other backbones.
        num_classes1, num_classes2, num_classes3: Output sizes of heads 1-3.
        input_features: Length of the flattened feature vector fed to each head.
        dropout_rate: Dropout probability applied after each hidden layer.
    """

    def __init__(self, sum_model, num_classes1, num_classes2, num_classes3, input_features, dropout_rate=0.5):
        super(FullyConnectedNetwork, self).__init__()
        self.sum_model = sum_model
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(dropout_rate)

        # One prompt generator per head: prompt_block1 .. prompt_block3.
        for head in (1, 2, 3):
            setattr(
                self,
                f'prompt_block{head}',
                PromptGenBlock(prompt_dim=768, prompt_len=5, prompt_size=8, lin_dim=768),
            )
        # Shared 1x1 conv that maps [features, prompt] (2 * 768 channels) back to 768.
        self.adjust_conv_layer = nn.Conv2d(768 * 2, 768, kernel_size=1)

        # Per-head layers, registered as fc1_k / bn1_k / fc2_k / bn2_k / fc3_k.
        head_layout = (
            (1, 512, 256, num_classes1),
            (2, 256, 128, num_classes2),
            (3, 128, 64, num_classes3),
        )
        for head, hidden_a, hidden_b, n_out in head_layout:
            setattr(self, f'fc1_{head}', nn.Linear(input_features, hidden_a))
            setattr(self, f'bn1_{head}', nn.BatchNorm1d(hidden_a))
            setattr(self, f'fc2_{head}', nn.Linear(hidden_a, hidden_b))
            setattr(self, f'bn2_{head}', nn.BatchNorm1d(hidden_b))
            setattr(self, f'fc3_{head}', nn.Linear(hidden_b, n_out))

    def _fuse_prompt(self, feats_nchw, head):
        """Concatenate features with head ``head``'s prompt, reduce channels, and flatten."""
        prompt_block = getattr(self, f'prompt_block{head}')
        fused = torch.cat([feats_nchw, prompt_block(feats_nchw)], dim=1)
        fused = self.adjust_conv_layer(fused)
        # Back to channels-last before flattening, as in the original layout.
        return self.flatten(fused.permute(0, 2, 3, 1))

    def _run_head(self, flat, head):
        """Apply head ``head``: two (linear, batch-norm, ReLU, dropout) stages, then the output layer."""
        hidden = flat
        for stage in (1, 2):
            linear = getattr(self, f'fc{stage}_{head}')
            norm = getattr(self, f'bn{stage}_{head}')
            hidden = self.dropout(torch.relu(norm(linear(hidden))))
        return getattr(self, f'fc3_{head}')(hidden)

    def forward(self, x):
        """Return the three heads' logits as ``(out1, out2, out3)``."""
        backbone_out = self.sum_model(x)
        feats_nchw = backbone_out.permute(0, 3, 1, 2)

        # All prompt fusions are computed first, then the heads run in order 1, 2, 3.
        flat_per_head = [self._fuse_prompt(feats_nchw, head) for head in (1, 2, 3)]
        out1 = self._run_head(flat_per_head[0], 1)
        out2 = self._run_head(flat_per_head[1], 2)
        out3 = self._run_head(flat_per_head[2], 3)

        return out1, out2, out3
    
    # Assuming the SUM class is already defined as per your model
if config.network == 'sum':
    sum_model = SUM(
        num_classes=model_cfg['num_classes'],
        input_channels=model_cfg['input_channels'],
        depths=model_cfg['depths'],
        depths_decoder=model_cfg['depths_decoder'],
        drop_path_rate=model_cfg['drop_path_rate'],
        load_ckpt_path=model_cfg['load_ckpt_path'],
    )
    sum_model.load_from()
    # Use the notebook-wide `device` instead of a hard-coded .cuda() so the
    # cell also runs on CPU-only hosts.
    sum_model.to(device)
else:
    # Without this guard the cell fails later with a NameError on `sum_model`.
    raise ValueError(
        f"Unsupported config.network {config.network!r}: this notebook only builds the 'sum' backbone."
    )


input_features = 7 * 7 * 768  # Adjust this value based on your model's output shape
model = FullyConnectedNetwork(sum_model, num_classes1, num_classes2, num_classes3, input_features)
model.to(device)

# Set up criterion, optimizer, and learning rate scheduler
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # Use CrossEntropyLoss with label smoothing
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
# The scheduler cuts the learning rate by 10x after 5 epochs without a drop in the monitored value.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases -- confirm against the installed version.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

# Initialize TensorBoard SummaryWriter
writer = SummaryWriter('runs/experiment_prompt2')

# Training function
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=25, patience=10, base_model_path="model_prompt.pth"):
    """Train the three-head model and return it loaded with its best validation weights.

    Every epoch runs a training pass followed by a validation pass. The loss is
    the sum of ``criterion`` over the three heads. Targets are the argmax of
    the one-hot label tensors yielded by the loaders. After each validation
    pass the scheduler is stepped with the validation loss. The weights with
    the highest mean validation accuracy are kept in memory and saved to
    ``base_model_path``.

    Uses the notebook-level ``device`` (tensor placement) and ``writer``
    (TensorBoard logging; closed before returning).

    Args:
        model: Module returning ``(out1, out2, out3)`` logits for a batch.
        train_loader: Loader yielding ``(inputs, [labels1, labels2, labels3])``.
        val_loader: Loader with the same batch structure, used for validation.
        criterion: Loss taking ``(logits, class_indices)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Scheduler whose ``step`` accepts the validation loss.
        num_epochs: Maximum number of epochs.
        patience: Stop after this many validation passes without improvement.
        base_model_path: File that receives the best ``state_dict``.

    Returns:
        ``model`` with the best-scoring weights loaded.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    epochs_no_improve = 0
    loaders = {'train': train_loader, 'val': val_loader}

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode
            data_loader = loaders[phase]

            running_loss = 0.0
            running_corrects1 = 0
            running_corrects2 = 0
            running_corrects3 = 0

            # Iterate over data
            for batch_idx, (inputs, labels) in enumerate(data_loader):
                inputs = inputs.to(device)
                labels1, labels2, labels3 = labels
                # One-hot rows -> class indices, which is what the criterion and accuracy need.
                targets1 = labels1.to(device).argmax(dim=1)
                targets2 = labels2.to(device).argmax(dim=1)
                targets3 = labels3.to(device).argmax(dim=1)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Gradients are tracked only during the training phase.
                with torch.set_grad_enabled(is_train):
                    outputs1, outputs2, outputs3 = model(inputs)
                    loss1 = criterion(outputs1, targets1)
                    loss2 = criterion(outputs2, targets2)
                    loss3 = criterion(outputs3, targets3)
                    loss = loss1 + loss2 + loss3

                    # Backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Statistics (loss is a batch mean, so weight it by the batch size)
                running_loss += loss.item() * inputs.size(0)
                running_corrects1 += (outputs1.argmax(dim=1) == targets1).sum().item()
                running_corrects2 += (outputs2.argmax(dim=1) == targets2).sum().item()
                running_corrects3 += (outputs3.argmax(dim=1) == targets3).sum().item()

                # Print loss status after each batch
                if is_train:
                    sys.stdout.write(f'\rBatch {batch_idx}/{len(data_loader) - 1} Loss: {loss.item():.4f}')
                    sys.stdout.flush()

            num_samples = len(data_loader.dataset)
            epoch_loss = running_loss / num_samples
            epoch_acc1 = running_corrects1 / num_samples
            epoch_acc2 = running_corrects2 / num_samples
            epoch_acc3 = running_corrects3 / num_samples

            # Log to TensorBoard
            writer.add_scalar(f'{phase}/Loss', epoch_loss, epoch)
            writer.add_scalar(f'{phase}/Accuracy1', epoch_acc1, epoch)
            writer.add_scalar(f'{phase}/Accuracy2', epoch_acc2, epoch)
            writer.add_scalar(f'{phase}/Accuracy3', epoch_acc3, epoch)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc1: {epoch_acc1:.4f} Acc2: {epoch_acc2:.4f} Acc3: {epoch_acc3:.4f}')

            # Checkpointing and early stopping are driven by the validation phase only.
            if phase == 'val':
                scheduler.step(epoch_loss)
                avg_acc = (epoch_acc1 + epoch_acc2 + epoch_acc3) / 3
                if avg_acc > best_acc:
                    best_acc = avg_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    epochs_no_improve = 0
                    torch.save(model.state_dict(), base_model_path)  # Save the base model
                    print(f"Best model save at epoch {epoch}")
                else:
                    epochs_no_improve += 1

                if epochs_no_improve >= patience:
                    print(f'Early stopping at epoch {epoch}')
                    model.load_state_dict(best_model_wts)
                    writer.close()
                    return model

        print()

    print(f'Best val Acc: {best_acc:.4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)
    writer.close()
    return model

In [21]:
# Set device: use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Train the model. The call below is currently disabled (commented out), so no
# training runs in this cell.
#model = train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=num_epochs, patience=5)

In [22]:
# Load the saved weights into `model`. Tensors are mapped onto the notebook's
# `device` rather than a hard-coded "cuda", so loading also works on CPU-only hosts.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint_path = "./net/pre_trained_weights/model_prompt.pth"
state_dict = torch.load(checkpoint_path, map_location=device)
# Kept as the cell's last expression so the key-matching report is still displayed.
model.load_state_dict(state_dict)

<All keys matched successfully>