In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.utils as vutils
import torchvision.models as models

from transformers import FlavaProcessor, FlavaModel

import numpy as np
import random
import datetime
from PIL import Image
import json

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

device: cuda


In [10]:
class SimpleMLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

class MetallographicDataset(Dataset):
    def __init__(self, data):
        self.data = data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sample = self.data[idx]

        image_path = sample['image']

        grade_and_process = torch.tensor([float(g) for g in sample['grade']] + [float(p) for p in sample['process']])

        inputs = {
            'image_path': image_path,
            'grade_and_process': grade_and_process,
        }

        labels = torch.tensor(sample['hardness_curve'], dtype=torch.float)

        return {**inputs, "labels": labels}


In [11]:
model = SimpleMLP(input_size=300 * 768, hidden_size=1024, output_size=14).to(device)
model.load_state_dict(torch.load('../../models/hardness/img+grade+process-epoch44-lr001-enhance.pth'))
model.eval()

model_path = "../../models/facebook-flava-full"
processor = FlavaProcessor.from_pretrained(model_path)
flava_model = FlavaModel.from_pretrained(model_path).to(device)

data_path = "../../datasets/data/test/data_cut.json"
with open(data_path, 'r') as f:
    data = json.load(f)
    
test_dataset = MetallographicDataset(data)
random_indices = random.sample(range(len(test_dataset)), 100)
test_subset = torch.utils.data.Subset(test_dataset, random_indices)
test_loader = DataLoader(test_subset, batch_size=1, shuffle=False)

# print(len(test_loader))

In [12]:
total_samples = 0
correct_predictions = 0

total_samples_mae = 0
total_mae = 0.0

ss_res = 0.0
ss_tot = 0.0
all_labels = []

predictions = []
simil_total = []
with torch.no_grad():
    for batch in test_loader:
        image_paths = batch['image_path']
        text_features = batch['grade_and_process'].to(device)
        
        images = [Image.open(image_path).convert("RGB") for image_path in image_paths]
        flava_inputs = processor(text=[str(text) for text in text_features],
                                 images=images,
                                 return_tensors="pt",
                                 padding=True).to(device)

        outputs = flava_model(**flava_inputs)
        multimodal_embeddings = outputs.multimodal_embeddings.to(device)

        desired_length = 300
        if multimodal_embeddings.size(1) > desired_length:
            multimodal_embeddings = multimodal_embeddings[:, :desired_length, :]
        elif multimodal_embeddings.size(1) < desired_length:
            padding_size = desired_length - multimodal_embeddings.size(1)
            padding = torch.zeros(multimodal_embeddings.size(0), padding_size, multimodal_embeddings.size(2)).to(device)
            multimodal_embeddings = torch.cat((multimodal_embeddings, padding), dim=1)

        predictions = model(multimodal_embeddings)
        labels = batch['labels'].to(device)

        simil = 20
        correct = torch.all(torch.abs(predictions - labels) <= simil, dim=1)
        correct_predictions += correct.sum().item()
        total_samples += len(correct)

        mae = torch.abs(predictions - labels).mean().item()
        total_mae += mae
        total_samples_mae += 1

        ss_res += torch.sum((predictions - labels) ** 2).item()
        all_labels.extend(labels.cpu().numpy())


accuracy = correct_predictions / total_samples * 100
print(f"Accuracy: {accuracy:.2f}%")

average_mae = total_mae / total_samples_mae
print(f"Mean Absolute Error (MAE): {average_mae:.4f}")

all_labels = np.array(all_labels)
mean_label = all_labels.mean()
ss_tot = np.sum((all_labels - mean_label) ** 2)
r2 = 1 - (ss_res / ss_tot)
print(f"R^2: {r2:.4f}")

Accuracy: 87.00%
Mean Absolute Error (MAE): 3.3711
R^2: 0.9914
