In [None]:
import os
import matplotlib.pyplot as plt
from transformers import BertTokenizer, BertModel
import torch 
import torch.optim as optim
import pandas as pd
import seaborn as sns
import numpy as np
import re
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix,f1_score
import torch.nn as nn
from eval import predict,PerfectMatch,AGE_TO_INDEX,GENDER_TO_INDEX,load_checkpoint

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): machine-specific absolute path. Set the HACKBAY_DIR
# environment variable to point at the project folder on another machine.
BASE_DIR = os.environ.get('HACKBAY_DIR', r'C:\Users\vi04wecu\Desktop\Hackbay')
DATA_DIR = os.path.join(BASE_DIR, 'processed_data')
test_df = pd.read_excel(os.path.join(DATA_DIR, 'hackbay_test_dataset.xlsx'))
test_labels = pd.read_csv(os.path.join(BASE_DIR, 'test_with_labels.csv'))

# Maximum number of characters preprocess_text keeps per text
# (presumably chosen to match BERT's 512-token input limit).
MAX_SEQ_LEN = 512

# Label-name -> class-index encodings applied to the test labels below.
# The index of each entry is its position in the list (shown in the comments).
# NOTE(review): these shadow GENDER_TO_INDEX / AGE_TO_INDEX imported from
# `eval`; confirm both definitions agree, otherwise the encoded labels will
# not line up with the model's output indices.
GENDER_TO_INDEX = {
    gender: position
    for position, gender in enumerate([
        'maennlich',  # 0
        'weiblich',   # 1
    ])
}
AGE_TO_INDEX = {
    age_group: position
    for position, age_group in enumerate([
        '16 bis 17 Jahre',    # 0
        '50 bis 54 Jahre',    # 1
        '65 bis 69 Jahre',    # 2
        '25 bis 29 Jahre',    # 3
        '14 bis 15 Jahre',    # 4
        '55 bis 59 Jahre',    # 5
        '10 bis 13 Jahre',    # 6
        '75 und mehr Jahre',  # 7
        '60 bis 64 Jahre',    # 8
        '35 bis 39 Jahre',    # 9
        '40 bis 44 Jahre',    # 10
        '70 bis 74 Jahre',    # 11
        '30 bis 34 Jahre',    # 12
        '45 bis 49 Jahre',    # 13
        '18 bis 19 Jahre',    # 14
        '20 bis 24 Jahre',    # 15
    ])
}

# Runs of non-word characters: anything except Unicode letters, digits and
# '_'. This includes '\n', '\t' and '\r'.
_NON_WORD_RE = re.compile(r'\W+', flags=re.UNICODE)

def preprocess_text(text, max_len=None):
  """Normalise raw text before tokenization.

  Every run of non-word characters (punctuation, symbols and whitespace,
  including line breaks) is collapsed to a single space. Letters, digits
  and '_' are kept. The result is stripped and truncated to ``max_len``
  characters.

  Args:
    text: the raw string. Missing values (None, or NaN, the only value
      not equal to itself) are returned as None.
    max_len: maximum number of characters to keep; defaults to MAX_SEQ_LEN.

  Returns:
    The cleaned string, or None for a missing value.
  """
  # Previously NaN fell through to an implicit None and None crashed in
  # re.sub; both are now handled explicitly.
  if text is None or text != text:
    return None
  if max_len is None:
    max_len = MAX_SEQ_LEN
  # No separate [\n\t\r] pass is needed: \W already matches them.
  text = _NON_WORD_RE.sub(' ', text).strip()
  # Truncation is by characters, not tokenizer tokens.
  return text[:max_len]

# Encode the string labels as integer class indices (an unseen label raises KeyError).
test_labels['age'] = test_labels['age'].apply(lambda x: AGE_TO_INDEX[x])
test_labels['gender'] = test_labels['gender'].apply(lambda x: GENDER_TO_INDEX[x])

# Prefix the title to the body text. A missing title is treated as empty, so
# a row is only dropped below when its body text is missing. Previously a NaN
# title made the whole concatenation NaN and discarded valid rows.
test_df['text'] = test_df['title'].fillna('') + ' ' + test_df['text']

# dropna leaves gaps in the index; reset it so index labels are 0..n-1 again.
test_df = test_df.dropna(subset=['text']).reset_index(drop=True)
print(test_df.shape)

# Turn '|' separators into spaces, then apply the regular text cleaning.
test_df['text'] = test_df['text'].apply(lambda x: preprocess_text(' '.join(x.split('|'))))
test_df.drop(columns=['title', 'keywords', 'colors', 'number_of_images'], inplace=True)

 
def load_params(model_path='perfect_match_model.pt',
                recommendations_path='hackbay_recommendations.xlsx'):
    """Build the PerfectMatch model, its tokenizer and the recommendation table.

    Args:
        model_path: checkpoint restored via load_checkpoint. If the file does
            not exist, a warning is emitted and the freshly constructed
            (untrained) PerfectMatch() is returned.
        recommendations_path: Excel file read into the recommendation DataFrame.

    Returns:
        A 3-tuple ``(model, tokenizer, recommendation_df)``. The model is on
        ``device`` and in eval mode.
    """
    import warnings

    model = PerfectMatch().to(device)
    if os.path.exists(model_path):
        # load_checkpoint expects an optimizer to restore into. It is not
        # needed for inference and is discarded.
        optimizer = optim.Adam(model.parameters(), lr=0.0001)
        model, _ = load_checkpoint(model_path, model, optimizer)
    else:
        warnings.warn(
            f"checkpoint {model_path!r} not found; using an untrained PerfectMatch model"
        )

    model = model.to(device)
    # Inference only: switch dropout/batch-norm layers to evaluation behaviour.
    model.eval()

    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')

    # Dataset used for making recommendations.
    recommendation_df = pd.read_excel(recommendations_path)
    return model, tokenizer, recommendation_df

# load_params() returns three values; unpacking into two names raised
# "ValueError: too many values to unpack".
model, tokenizer, recommendation_df = load_params()

# Iterate over the values rather than index labels: label lookups such as
# test_df['text'][i] raise KeyError when the index has gaps (e.g. after
# dropna). Assigning a whole column also avoids chained assignment, which
# pandas may apply to a temporary copy instead of test_df.
predictions = [predict(text, model, tokenizer) for text in test_df['text']]
test_df['prediction'] = pd.Series(predictions, index=test_df.index, dtype=object)

# NOTE(review): no `on=` is given, so this joins on every column name shared
# by test_df and test_labels. Confirm those are only the intended key column(s).
test_pred_out = pd.merge(test_df, test_labels, how='inner')

def compute_metrics(labels, probs):
  """Return weighted F1 and accuracy for a batch of predictions.

  Args:
    labels: tensor of true class indices, one per sample. It is flattened,
      so e.g. shape (N,) or (N, 1) both work.
    probs: tensor of per-class scores with classes along dim 1. Raw logits
      or probabilities both work, because only the argmax is used.

  Returns:
    dict with 'f1' (weighted F1 score) and 'accuracy'.
  """
  # Softmax is monotonic, so the argmax over the raw scores selects the same
  # classes without the extra pass.
  pred_ids = torch.argmax(probs, dim=1).reshape(-1).cpu().tolist()
  # reshape(-1) rather than squeeze(): squeeze() turns a batch of one into a
  # 0-d tensor whose .tolist() is a bare int, which the metric functions reject.
  true_ids = labels.reshape(-1).cpu().tolist()
  acc = accuracy_score(true_ids, pred_ids)
  f1 = f1_score(true_ids, pred_ids, average='weighted')
  return {'f1': f1, 'accuracy': acc}