### Imports

In [1]:
import pandas as pd

### Load data

In [2]:
train_users = pd.read_csv("train/train_users.csv")
train_reviews = pd.read_csv("train/train_reviews.csv")
train_matches = pd.read_csv("train/train_matches.csv")

In [3]:
val_users = pd.read_csv("val/val_users.csv")
val_reviews = pd.read_csv("val/val_reviews.csv")
val_matches = pd.read_csv("val/val_matches.csv")

In [4]:
test_users = pd.read_csv("test/test_users.csv")

In [1]:
import torch
import pandas as pd

# 1. Load the test set DataFrames
test_users = pd.read_csv("./test/test_users.csv")    # Contains user IDs, accommodation IDs, and user features
test_reviews = pd.read_csv("./test/test_reviews.csv")  # Contains review_id, review_positive, review_negative


In [13]:
from tqdm import tqdm

@torch.no_grad()
def predict(model, test_users, test_reviews, top_k=10):
    """
    Predict top-k most similar reviews for each user in test_users.
    """
    model.eval()
    results = []
    
    device = next(model.parameters()).device

    # Wrap test_users with tqdm for progress tracking
    for _, user_row in tqdm(test_users.iterrows(), total=len(test_users), desc="Processing Users"):
        # Exclude columns 'user_id' & 'accommodation_id'
        # Assuming user_row is [user_id, accommodation_id, f1, f2, ..., fN]
        user_features_vals = [
            float(value) if isinstance(value, (int, float)) else 0
            for value in user_row.values[:]
        ]

        user_features = torch.tensor(
            user_features_vals,
            dtype=torch.float32
        ).unsqueeze(0).to(device)

        similarities = []

        # Wrap test_reviews with tqdm if desired for inner loop tracking
        for _, review_row in test_reviews.iterrows():
            review_content = f"{review_row['review_positive']} {review_row['review_negative']}"
            logits = model(user_features, [review_content])
            prob = torch.sigmoid(logits).item()  # Convert logits to probability
            similarities.append((review_row['review_id'], prob))

        # Sort by similarity descending
        similarities.sort(key=lambda x: x[1], reverse=True)
        top_reviews = [review_id for (review_id, _) in similarities[:top_k]]
        results.append((user_row['accommodation_id'], user_row['user_id'], *top_reviews))

    result_df = pd.DataFrame(
        results,
        columns=["accommodation_id", "user_id"] + [f"review_{i}" for i in range(1, top_k + 1)]
    )
    # Add an 'ID' column if it doesn't exist
    if 'ID' not in result_df.columns:
        result_df.insert(0, 'ID', range(1, len(result_df) + 1))

    return result_df


In [16]:
from train import ContrastiveModel

model = ContrastiveModel(13).to("cpu")

test_users_small = test_users[:100]
test_reviews_small = test_reviews[:100]

result_df = predict(model, test_users_small, test_reviews_small)

Processing Users: 100%|██████████| 100/100 [00:24<00:00,  4.06it/s]


In [17]:
result_df

Unnamed: 0,ID,accommodation_id,user_id,review_1,review_2,review_3,review_4,review_5,review_6,review_7,review_8,review_9,review_10
0,1,2086452554,5f83c2ae-d803-4b4c-9d25-1226f90297ce,1c57c05f-d248-4f3e-ba6b-8dd3a53101b4,7f6f5708-9318-4559-973b-0a31e3d0a767,dc86013d-24cc-41b0-bb98-532c9c384ce3,50b111dd-2c4e-4189-a424-1e04b380d782,c28cee37-dcdf-48e0-93bc-93461f94db04,b056caa0-91c9-43e1-a64c-e0dd9e541c49,4f7ae8ba-c241-4376-8ab3-e7131ec046a7,ce4d5d58-156f-4d9a-a732-73894986825d,a357041f-e466-430e-98b4-0a9142011c96,c0c3741d-5f4b-4fc3-9a55-6038f76a8b91
1,2,-202362622,a194a2ef-9487-4cf0-8828-dd5803c8b9d1,775c172c-8199-4e87-adb6-57f9f0d76231,84c975b2-6dfb-4429-bafa-b4fdf48addf4,f1a6af46-bc53-4143-8c40-cf06acc1c4db,8c83e5c3-0067-4eec-a29a-421c6f9b8439,06803a55-6225-4cba-b6d7-af91dc730ee6,ac608237-35b7-41a8-b658-e606c55eaca7,a39a4246-e429-4597-a7f0-df5152b6ce45,16506872-a714-4323-91b0-af7735a2692b,f181394d-35a0-4e27-9cdc-84427c53d321,f894e158-4d2f-41d2-931a-dbd7254bd88a
2,3,-1390928232,cfb878d0-af56-4b0d-90ff-87095b1a56d6,775c172c-8199-4e87-adb6-57f9f0d76231,84c975b2-6dfb-4429-bafa-b4fdf48addf4,f1a6af46-bc53-4143-8c40-cf06acc1c4db,8c83e5c3-0067-4eec-a29a-421c6f9b8439,06803a55-6225-4cba-b6d7-af91dc730ee6,ac608237-35b7-41a8-b658-e606c55eaca7,a39a4246-e429-4597-a7f0-df5152b6ce45,16506872-a714-4323-91b0-af7735a2692b,f181394d-35a0-4e27-9cdc-84427c53d321,f894e158-4d2f-41d2-931a-dbd7254bd88a
3,4,1007230055,19ffcbff-8500-482a-b5af-c55cb4235259,1c57c05f-d248-4f3e-ba6b-8dd3a53101b4,7f6f5708-9318-4559-973b-0a31e3d0a767,dc86013d-24cc-41b0-bb98-532c9c384ce3,50b111dd-2c4e-4189-a424-1e04b380d782,c28cee37-dcdf-48e0-93bc-93461f94db04,b056caa0-91c9-43e1-a64c-e0dd9e541c49,4f7ae8ba-c241-4376-8ab3-e7131ec046a7,ce4d5d58-156f-4d9a-a732-73894986825d,a357041f-e466-430e-98b4-0a9142011c96,c0c3741d-5f4b-4fc3-9a55-6038f76a8b91
4,5,135365139,98d6a06b-131c-464d-86e7-b74dd4894ae2,1c57c05f-d248-4f3e-ba6b-8dd3a53101b4,7f6f5708-9318-4559-973b-0a31e3d0a767,dc86013d-24cc-41b0-bb98-532c9c384ce3,50b111dd-2c4e-4189-a424-1e04b380d782,c28cee37-dcdf-48e0-93bc-93461f94db04,b056caa0-91c9-43e1-a64c-e0dd9e541c49,4f7ae8ba-c241-4376-8ab3-e7131ec046a7,ce4d5d58-156f-4d9a-a732-73894986825d,a357041f-e466-430e-98b4-0a9142011c96,c0c3741d-5f4b-4fc3-9a55-6038f76a8b91
...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,96,307364189,c682c957-032a-49f0-a29d-03580bc2826b,1c57c05f-d248-4f3e-ba6b-8dd3a53101b4,7f6f5708-9318-4559-973b-0a31e3d0a767,dc86013d-24cc-41b0-bb98-532c9c384ce3,50b111dd-2c4e-4189-a424-1e04b380d782,c28cee37-dcdf-48e0-93bc-93461f94db04,b056caa0-91c9-43e1-a64c-e0dd9e541c49,4f7ae8ba-c241-4376-8ab3-e7131ec046a7,ce4d5d58-156f-4d9a-a732-73894986825d,a357041f-e466-430e-98b4-0a9142011c96,c0c3741d-5f4b-4fc3-9a55-6038f76a8b91
96,97,-1336323605,6dc65a28-7c94-493f-8936-eb6d9f9b29c9,775c172c-8199-4e87-adb6-57f9f0d76231,84c975b2-6dfb-4429-bafa-b4fdf48addf4,f1a6af46-bc53-4143-8c40-cf06acc1c4db,8c83e5c3-0067-4eec-a29a-421c6f9b8439,06803a55-6225-4cba-b6d7-af91dc730ee6,ac608237-35b7-41a8-b658-e606c55eaca7,a39a4246-e429-4597-a7f0-df5152b6ce45,16506872-a714-4323-91b0-af7735a2692b,f181394d-35a0-4e27-9cdc-84427c53d321,f894e158-4d2f-41d2-931a-dbd7254bd88a
97,98,2086982600,0ae2a426-0f9c-4423-9115-aac76573b490,1c57c05f-d248-4f3e-ba6b-8dd3a53101b4,7f6f5708-9318-4559-973b-0a31e3d0a767,dc86013d-24cc-41b0-bb98-532c9c384ce3,50b111dd-2c4e-4189-a424-1e04b380d782,c28cee37-dcdf-48e0-93bc-93461f94db04,b056caa0-91c9-43e1-a64c-e0dd9e541c49,4f7ae8ba-c241-4376-8ab3-e7131ec046a7,ce4d5d58-156f-4d9a-a732-73894986825d,a357041f-e466-430e-98b4-0a9142011c96,c0c3741d-5f4b-4fc3-9a55-6038f76a8b91
98,99,-1623214578,085214d9-42b4-4fb7-8624-12130a5bdf2f,775c172c-8199-4e87-adb6-57f9f0d76231,84c975b2-6dfb-4429-bafa-b4fdf48addf4,f1a6af46-bc53-4143-8c40-cf06acc1c4db,8c83e5c3-0067-4eec-a29a-421c6f9b8439,06803a55-6225-4cba-b6d7-af91dc730ee6,ac608237-35b7-41a8-b658-e606c55eaca7,a39a4246-e429-4597-a7f0-df5152b6ce45,16506872-a714-4323-91b0-af7735a2692b,f181394d-35a0-4e27-9cdc-84427c53d321,f894e158-4d2f-41d2-931a-dbd7254bd88a
