In [1]:
# =====================================================
# 🧠 S3 AUGMENTATION NOTEBOOK (with valid NewsID filter)
# =====================================================

# Cell 1: Imports
import torch
from torch import nn
import torch.nn.functional as F
import torch.nn.init as init
import pandas as pd
import numpy as np


In [2]:
# =====================================================
# Cell 2: Define S3 and S3Layer (no logic change)
# =====================================================
class S3(nn.Module):
    def __init__(self, num_layers, initial_num_segments, shuffle_vector_dim=1, segment_multiplier=1,
                 segments_per_layer=None, use_conv_w_avg=True, initialization_type="kaiming", use_stitch=True):
        super(S3, self).__init__()
        self.S3_layers = nn.ModuleList()
        next_segment_num = initial_num_segments
        for i in range(0, num_layers):
            print("Building S3 Layer", i)
            self.S3_layers += [S3Layer(num_segments=next_segment_num,
                                       shuffle_vector_dim=shuffle_vector_dim,
                                       use_conv_w_avg=use_conv_w_avg,
                                       initialization_type=initialization_type,
                                       use_stitch=use_stitch)]
            next_segment_num = int(next_segment_num * segment_multiplier)
            if next_segment_num == 0:
                next_segment_num = 1

    def forward(self, x):
        x_copy = x.clone()
        for S3_layer in self.S3_layers:
            if x_copy.shape[1] >= S3_layer.num_segments:
                sample_num_to_truncate = x_copy.shape[1] % S3_layer.num_segments
                if sample_num_to_truncate > 0:
                    x_copy = x_copy[:, sample_num_to_truncate:, :]
                x_copy = S3_layer(x_copy)
        if x.shape[1] > x_copy.shape[1]:
            x_copy = torch.cat([x[:, 0:x.shape[1] - x_copy.shape[1], :], x_copy], dim=1)
        return x_copy


class S3Layer(nn.Module):
    def __init__(self, num_segments, shuffle_vector_dim=1, use_conv_w_avg=True, initialization_type="kaiming", use_stitch=True):
        super(S3Layer, self).__init__()
        self.num_segments = int(num_segments)
        self.activation = "relu"
        self.use_conv_w_avg = use_conv_w_avg
        self.initialization_type = initialization_type
        self.use_stitch = use_stitch
        self.shuffle_vector_dim = shuffle_vector_dim
        shuffle_vector_shape = tuple([self.num_segments] * self.shuffle_vector_dim)
        self.shuffle_vector = nn.Parameter(torch.empty(shuffle_vector_shape))
        self.initialize_shuffle_vector()
        self.descending_indices = None
        if self.use_conv_w_avg:
            self.w_avg = nn.Conv1d(in_channels=2, out_channels=1, kernel_size=1)
        else:
            self.weights = nn.Parameter(torch.ones(2) * 0.5)

    def initialize_shuffle_vector(self):
        if self.initialization_type == "kaiming" and self.shuffle_vector_dim > 1:
            init.kaiming_normal_(self.shuffle_vector, mode='fan_out', nonlinearity=self.activation)
            scale_factor, shift_value = 0.001, 0.01
            self.shuffle_vector.data.mul_(scale_factor).add_(shift_value)
        elif self.initialization_type == "manual" or self.shuffle_vector_dim == 1:
            scale_factor, shift_value = 0.1, 0.5
            self.shuffle_vector.data.fill_(scale_factor).add_(shift_value)
        else:
            raise ValueError(f"Unsupported initialization type: {self.initialization_type}")

    def forward(self, x):
        if self.shuffle_vector.device != x.device:
            self.shuffle_vector = self.shuffle_vector.to(x.device)
        total_time_steps = x.size(1)
        segments = torch.chunk(x, self.num_segments, dim=1)
        if len(self.shuffle_vector.shape) > 1:
            shuffle_vector_sum = self.shuffle_vector.sum(tuple(range(len(self.shuffle_vector.shape)-1)))
        else:
            shuffle_vector_sum = self.shuffle_vector
        self.descending_indices = torch.argsort(shuffle_vector_sum, descending=True)
        result_matrix = torch.zeros((len(shuffle_vector_sum), len(shuffle_vector_sum)), device=x.device)
        result_matrix.scatter_(1, self.descending_indices.unsqueeze(1), shuffle_vector_sum.unsqueeze(1))
        non_zero_mask = result_matrix != 0
        scaling_factors = 1.0 / torch.abs(result_matrix[non_zero_mask])
        result_matrix[non_zero_mask] *= scaling_factors
        result_matrix = torch.abs(result_matrix)
        stacked_segments = torch.stack(segments, dim=-1)
        stacked_segments = stacked_segments.unsqueeze(-1).expand(-1, -1, -1, -1, stacked_segments.shape[-1])
        multiplication_out = stacked_segments * torch.transpose(result_matrix, 0, 1)
        multiplication_out = torch.transpose(multiplication_out, -2, -1)
        shuffled_segments_stack = multiplication_out.sum(dim=-1)
        shuffled_segments_list = shuffled_segments_stack.unbind(dim=-1)
        concatenated_segments = torch.cat(shuffled_segments_list, dim=1)
        b, t, c = x.shape
        if self.use_stitch:
            stacked_shuffle_original = torch.stack((concatenated_segments, x), dim=-1)
            if self.use_conv_w_avg:
                stacked_shuffle_original_reshaped = stacked_shuffle_original.view(b * t * c, 2, 1)
                out = self.w_avg(stacked_shuffle_original_reshaped)
                out = out.view(b, t, c)
                return out
            else:
                weights_normalized = torch.softmax(self.weights, dim=0)
                out = (stacked_shuffle_original * weights_normalized).sum(dim=-1)
            return out
        else:
            return concatenated_segments


In [52]:
# =====================================================
# Cell 3: Load both datasets
# =====================================================
# User dataset
df_users = pd.read_csv("personalized_test.tsv", sep="\t")

# News reference dataset (to filter valid IDs)
df_news = pd.read_csv("news_min (2).tsv", sep="\t")

# Build valid NewsID set (string form, like "N1", "N2")
valid_news_ids = set(df_news["News ID"].astype(str).tolist())

In [53]:
df_users

Unnamed: 0,userid,clicknewsID,posnewID,rewrite_titles
0,NT1,"N108480,N38238,N35068,N110487,N94904,N72378,N4...","N24110,N62769,N36186,N101669,N19241,N72921,N26...",Legal battle looms over Trump EPA's rule chang...
1,NT2,"N34682,N113236,N119039,N90826,N63278,N27346,N5...","N51765,N37815,N109881,N64357,N13381,N45697,N57...",What You Need to Know About GMOs;;What's Up wi...
2,NT3,"N106204,N74279,N55583,N90083,N117690,N91663,N9...","N96078,N11699,N13028,N36049,N87968,N105007,N11...",Don't Know What's Popular This Summer? We've G...
3,NT4,"N61892,N41396,N42145,N24440,N74099,N73577,N123...","N15817,N104663,N10362,N69465,N16287,N70636,N83...",Summer heat putting your pets at risk;;Trip Ad...
4,NT5,"N79801,N52642,N19270,N112075,N37402,N120660,N3...","N61157,N69119,N101472,N122218,N92462,N67440,N5...",Top News Stories from Texas ;;Some Simple Tips...
...,...,...,...,...
98,NT99,"N74855,N70285,N97607,N14984,N101784,N65808,N28...","N55099,N48939,N85789,N32617,N10476,N23495,N747...",National News Updates ;;Fruit Tea Expansion |...
99,NT100,"N80527,N42741,N32568,N95477,N86762,N77781,N533...","N28172,N64220,N108207,N112458,N108750,N51009,N...",The Blue Jays lead by Richard beat the Royals ...
100,NT101,"N14290,N116936,N110697,N110669,N57257,N94449,N...","N33068,N120666,N85039,N26146,N46240,N122884,N6...",Murdered Father A Hero By Donating Organs To S...
101,NT102,"N101579,N19049,N116697,N106313,N76716,N106985,...","N26153,N93627,N122237,N120408,N105451,N66158,N...",Sacramento State Captiol building to hang Prid...


In [54]:
# =====================================================
# Cell 4: Initialize the S3 model
# =====================================================
s3_model = S3(
    num_layers=20,
    initial_num_segments=5,
    shuffle_vector_dim=3,
    use_conv_w_avg=True,
    initialization_type="kaiming",
    use_stitch=True
)


Building S3 Layer 0
Building S3 Layer 1
Building S3 Layer 2
Building S3 Layer 3
Building S3 Layer 4
Building S3 Layer 5
Building S3 Layer 6
Building S3 Layer 7
Building S3 Layer 8
Building S3 Layer 9
Building S3 Layer 10
Building S3 Layer 11
Building S3 Layer 12
Building S3 Layer 13
Building S3 Layer 14
Building S3 Layer 15
Building S3 Layer 16
Building S3 Layer 17
Building S3 Layer 18
Building S3 Layer 19


In [55]:
# =====================================================
# Cell 5: Helper function for augmentation
# =====================================================
def apply_s3_augmentation(clicknews_ids, model, valid_ids):
    # Keep only valid IDs
    filtered_ids = [id_ for id_ in clicknews_ids if id_ in valid_ids]
    if len(filtered_ids) == 0:
        return clicknews_ids  # if all invalid, return original

    ids_tensor = torch.tensor(filtered_ids, dtype=torch.float32).unsqueeze(0).unsqueeze(-1)
    with torch.no_grad():
        augmented = model(ids_tensor)
    augmented_ids = list(np.round(augmented.squeeze(0).squeeze(-1).cpu().numpy()).astype(int))

    # Re-map to valid IDs again (clip/re-map if out of range)
    final_ids = [id_ if id_ in valid_ids else np.random.choice(list(valid_ids)) for id_ in augmented_ids]
    return final_ids


In [56]:
df_users["ClicknewsID_original"] = df_users["clicknewsID"].copy()

In [57]:
# =====================================================
# Cell 6: Augment clicknewsID column (guaranteed reorder)
# =====================================================
from tqdm import tqdm
import re
import random
import numpy as np

def extract_num(nid):
    return int(re.sub("[^0-9]", "", str(nid))) if re.search(r"\d+", str(nid)) else 0

def build_id(num):
    return f"N{int(num)}"

def force_shuffle(ids, strength=0.3):
    """Randomly shuffle ~strength fraction of segments."""
    n = len(ids)
    k = max(2, int(n * strength))
    idx = np.arange(n)
    swap_idx = np.random.choice(n, k, replace=False)
    np.random.shuffle(swap_idx)
    idx[:k] = swap_idx
    return [ids[i] for i in idx]

valid_nums = sorted([extract_num(nid) for nid in valid_news_ids if extract_num(nid) > 0])
min_id, max_id = min(valid_nums), max(valid_nums)

augmented_clicks = []

for idx, row in tqdm(df_users.iterrows(), total=len(df_users), desc="Applying S3 shuffle"):
    try:
        ids_raw = row['clicknewsID'].replace(",", " ").split() if isinstance(row['clicknewsID'], str) else list(row['clicknewsID'])
        ids = [extract_num(x) for x in ids_raw if extract_num(x) != 0]
        if len(ids) == 0:
            raise ValueError("No valid IDs found")

        # Randomize shuffle vectors per user
        with torch.no_grad():
            for layer in s3_model.S3_layers:
                layer.shuffle_vector.data = torch.rand_like(layer.shuffle_vector)

        # S3 forward pass
        augmented_tensor = apply_s3_augmentation(ids, s3_model, valid_news_ids)
        augmented_array = np.round(np.clip(augmented_tensor, min_id, max_id)).astype(int).tolist()

        # 🔹 FORCE extra segment shuffle if identical
        if augmented_array == ids:
            augmented_array = force_shuffle(augmented_array, strength=0.35)

        new_ids = [build_id(num) for num in augmented_array]
        augmented_clicks.append(" ".join(new_ids))

    except Exception as e:
        print(f"Error at row {idx}: {e}")
        augmented_clicks.append(row['clicknewsID'])

df_users['clicknewsID'] = augmented_clicks


Applying S3 shuffle: 100%|██████████| 103/103 [00:00<00:00, 2385.60it/s]


In [58]:
def normalize_ids(seq):
    # turn "N1,N2 N3" → ['N1','N2','N3']
    if pd.isna(seq): return []
    return [x.strip() for x in str(seq).replace(",", " ").split() if x.strip()]

def same_order(a, b):
    return normalize_ids(a) == normalize_ids(b)

df_users["actually_changed"] = df_users.apply(
    lambda r: not same_order(r["ClicknewsID_original"], r["clicknewsID"]), axis=1
)

print(f"True reordered rows: {df_users['actually_changed'].sum()} / {len(df_users)}")


True reordered rows: 103 / 103


In [59]:
df_users.loc[:10, ["userid", "ClicknewsID_original", "clicknewsID"]]


Unnamed: 0,userid,ClicknewsID_original,clicknewsID
0,NT1,"N108480,N38238,N35068,N110487,N94904,N72378,N4...",N45937 N35350 N113593 N69110 N121447 N94904 N9...
1,NT2,"N34682,N113236,N119039,N90826,N63278,N27346,N5...",N93469 N121528 N13048 N12084 N12292 N19422 N67...
2,NT3,"N106204,N74279,N55583,N90083,N117690,N91663,N9...",N35146 N24142 N26873 N30794 N121539 N48192 N48...
3,NT4,"N61892,N41396,N42145,N24440,N74099,N73577,N123...",N88866 N57373 N16157 N14494 N28237 N70613 N111...
4,NT5,"N79801,N52642,N19270,N112075,N37402,N120660,N3...",N120166 N108942 N123662 N31578 N26134 N52357 N...
5,NT6,"N84892,N48249,N42564,N36344,N29518,N51371,N390...",N34016 N25993 N88722 N45855 N74334 N112019 N18...
6,NT7,"N88143,N80548,N119039,N104554,N96735,N31699,N2...",N31699 N88143 N39254 N93804 N84898 N84230 N330...
7,NT8,"N97687,N33613,N54918,N90798,N102005,N102119,N9...",N57556 N86383 N66806 N44810 N62659 N41186 N106...
8,NT9,"N61197,N23952,N77355,N100360,N74042,N115835,N2...",N38798 N23952 N108595 N59913 N74042 N21503 N10...
9,NT10,"N80815,N123208,N79308,N77296,N109867,N62873,N1...",N36378 N52351 N93678 N39963 N119792 N48111 N94...


In [60]:
df_users

Unnamed: 0,userid,clicknewsID,posnewID,rewrite_titles,ClicknewsID_original,actually_changed
0,NT1,N45937 N35350 N113593 N69110 N121447 N94904 N9...,"N24110,N62769,N36186,N101669,N19241,N72921,N26...",Legal battle looms over Trump EPA's rule chang...,"N108480,N38238,N35068,N110487,N94904,N72378,N4...",True
1,NT2,N93469 N121528 N13048 N12084 N12292 N19422 N67...,"N51765,N37815,N109881,N64357,N13381,N45697,N57...",What You Need to Know About GMOs;;What's Up wi...,"N34682,N113236,N119039,N90826,N63278,N27346,N5...",True
2,NT3,N35146 N24142 N26873 N30794 N121539 N48192 N48...,"N96078,N11699,N13028,N36049,N87968,N105007,N11...",Don't Know What's Popular This Summer? We've G...,"N106204,N74279,N55583,N90083,N117690,N91663,N9...",True
3,NT4,N88866 N57373 N16157 N14494 N28237 N70613 N111...,"N15817,N104663,N10362,N69465,N16287,N70636,N83...",Summer heat putting your pets at risk;;Trip Ad...,"N61892,N41396,N42145,N24440,N74099,N73577,N123...",True
4,NT5,N120166 N108942 N123662 N31578 N26134 N52357 N...,"N61157,N69119,N101472,N122218,N92462,N67440,N5...",Top News Stories from Texas ;;Some Simple Tips...,"N79801,N52642,N19270,N112075,N37402,N120660,N3...",True
...,...,...,...,...,...,...
98,NT99,N59318 N28082 N65808 N32336 N39218 N101784 N36...,"N55099,N48939,N85789,N32617,N10476,N23495,N747...",National News Updates ;;Fruit Tea Expansion |...,"N74855,N70285,N97607,N14984,N101784,N65808,N28...",True
99,NT100,N77781 N57028 N42741 N33174 N123152 N72086 N64...,"N28172,N64220,N108207,N112458,N108750,N51009,N...",The Blue Jays lead by Richard beat the Royals ...,"N80527,N42741,N32568,N95477,N86762,N77781,N533...",True
100,NT101,N38428 N57582 N27350 N20195 N105610 N53706 N43...,"N33068,N120666,N85039,N26146,N46240,N122884,N6...",Murdered Father A Hero By Donating Organs To S...,"N14290,N116936,N110697,N110669,N57257,N94449,N...",True
101,NT102,N73387 N58542 N74476 N31962 N103705 N42570 N16...,"N26153,N93627,N122237,N120408,N105451,N66158,N...",Sacramento State Captiol building to hang Prid...,"N101579,N19049,N116697,N106313,N76716,N106985,...",True


In [61]:
df_users.to_csv("S3_on_test_5_segments.csv")