In [1]:
import os
import struct
import socket
import pickle
import time

from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler

In [2]:
# Divisor used further down to size one client's share of the feedback
# training labels (num_traindata = len(fb_labels_train) // user).
user = 2

In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

In [4]:
# Device every batch tensor is moved to in the training loops.
device = "cpu"

# Seed every RNG the notebook draws from. NumPy is included because
# load_embeddings() samples the UNK vector with np.random.randn; without a
# NumPy seed that row differs on every kernel restart.
SEED = 777
np.random.seed(SEED)
torch.manual_seed(SEED)
if device.startswith("cuda"):
    # Covers "cuda", "cuda:0", "cuda:1", ... rather than only the literal "cuda:0".
    torch.cuda.manual_seed_all(SEED)

In [5]:
def _load_pickle(path):
    """Return the object unpickled from ``path``; the file is closed on exit."""
    # pickle.load can execute arbitrary code: only point this at trusted files.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Initial data set and its ratings.
train_data = _load_pickle("recsys_data/data.p")
lab_ratings = _load_pickle("recsys_data/lab_ratings.p")
# Feedback data: train split.
fb_data_train = _load_pickle("recsys_data/fb_data_train.p")
fb_labels_train = _load_pickle("recsys_data/fb_labels_train.p")
# Feedback data: test split.
fb_data_test = _load_pickle("recsys_data/fb_data_test.p")
fb_labels_test = _load_pickle("recsys_data/fb_labels_test.p")

In [6]:
def load_embeddings(embed_file):
    """Load whitespace-separated embeddings and L2-normalise every row.

    Each non-blank line is ``<token> <v1> <v2> ...``; the token is dropped and
    the remaining fields become one row. One extra row is appended at the end
    for unknown tokens: a Gaussian random vector, normalised the same way.

    Args:
        embed_file: path of the text embedding file.

    Returns:
        ``np.ndarray`` with one row per non-blank line of the file plus the
        trailing UNK row.

    Raises:
        ValueError: if the file contains no embedding rows, so the UNK
            dimension cannot be determined.
    """
    W = []
    with open(embed_file) as ef:
        for line in ef:
            fields = line.rstrip().split()
            if not fields:
                # A blank line would otherwise yield an empty vector and make
                # the final matrix ragged.
                continue
            # Builtin float: the np.float alias was removed in NumPy 1.24.
            vec = np.array(fields[1:]).astype(float)
            # The epsilon keeps an all-zero row from dividing by zero.
            vec = vec / float(np.linalg.norm(vec) + 1e-6)
            W.append(vec)
    if not W:
        raise ValueError("no embeddings found in %r" % (embed_file,))
    #UNK embedding, gaussian randomly initialized
    print("adding unk embedding")
    vec = np.random.randn(len(W[-1]))
    vec = vec / float(np.linalg.norm(vec) + 1e-6)
    W.append(vec)
    return np.array(W)

In [7]:
# Path of the text embedding file that is passed to load_embeddings().
EMBED_FILE = 'recsys_data/processed_full.embed'

In [8]:
# Normalised embedding rows plus the trailing UNK row; RecSysClient copies this
# array into its nn.Embedding weights.
embedding_matrix = load_embeddings(EMBED_FILE)
# embedding_matrix.shape  # attribute, not a method; uncomment to inspect the size

adding unk embedding


In [9]:
import pandas as pd
# vocab.csv is read without a header row; column 0 is what load_vocab_dict()
# later turns into the index <-> word maps.
vocab = pd.read_csv('recsys_data/vocab.csv', header=None)
# Cell output: number of distinct entries in column 0 (the sort does not
# change the count).
len(set(sorted(vocab[0].tolist())))

51917

In [10]:
def load_vocab_dict(vocab_file):
    """Build index<->word lookup tables from the first column of a CSV file.

    The file is read without a header row. The distinct entries of column 0
    are sorted and numbered from 1, which leaves index 0 unused by this
    mapping.

    Returns:
        ``(ind2w, w2ind)``: index -> word and word -> index dictionaries.
    """
    frame = pd.read_csv(vocab_file, header=None)
    words = sorted(set(frame[0].tolist()))
    ind2w = dict(enumerate(words, start=1))
    w2ind = {word: index for index, word in ind2w.items()}
    return ind2w, w2ind

In [11]:
# Index -> word and word -> index maps (indices start at 1); w2idx is what
# encoding_disease() uses to turn tokens into integer ids.
idx2w, w2idx = load_vocab_dict('recsys_data/vocab.csv')

In [12]:
def clean_text(text):
    """Delete every ``[``, ``]``, ``'`` and ``,`` from ``text`` and split on whitespace.

    Returns:
        The list of remaining whitespace-separated tokens (possibly empty).
    """
    without_punctuation = text.translate(str.maketrans('', '', "[]',"))
    return without_punctuation.split()

def encoding_disease(data, w2idx):
    """Turn every record's disease text into a list of vocabulary indices.

    Args:
        data: mapping with a ``'disease'`` entry indexable by row position and
            an ``'age'`` entry whose length gives the number of records.
        w2idx: word -> index dictionary. Tokens missing from it are encoded as
            ``len(w2idx) + 1``.

    Returns:
        A regular ``np.ndarray`` when all records yield the same number of
        tokens (the result ``np.array`` has always produced), otherwise a 1-D
        object array holding one Python list per record.
    """
    unk_index = len(w2idx) + 1
    idx_total = [
        [w2idx.get(token, unk_index) for token in clean_text(data['disease'][i])]
        for i in range(len(data['age']))
    ]
    if len({len(seq) for seq in idx_total}) <= 1:
        return np.array(idx_total)
    # Ragged input: np.array() on unequal-length lists emits a
    # VisibleDeprecationWarning on older NumPy and raises ValueError from
    # NumPy 1.24, so the object array is built explicitly.
    ragged = np.empty(len(idx_total), dtype=object)
    for position, seq in enumerate(idx_total):
        ragged[position] = seq
    return ragged

In [13]:
def pad_sequences(sequences, max_seq_len: int = 0):
    """Right-pad variable-length sequences with zeros into one 2-D array.

    Args:
        sequences: sized collection of sequences of numbers.
        max_seq_len: minimum width of the result; the actual width is the
            larger of this value and the longest sequence.

    Returns:
        ``np.ndarray`` (NumPy's default float dtype) of shape
        ``(len(sequences), width)``. An empty ``sequences`` gives an array with
        zero rows instead of raising.
    """
    # default=0 keeps max() from raising ValueError on an empty input.
    longest = max((len(sequence) for sequence in sequences), default=0)
    max_seq_len = max(max_seq_len, longest)
    # Pad
    padded_sequences = np.zeros((len(sequences), max_seq_len))
    for i, sequence in enumerate(sequences):
        padded_sequences[i, : len(sequence)] = sequence
    return padded_sequences



In [14]:
# For each data set: encode the disease text to vocabulary indices, zero-pad
# to that data set's own longest sequence, and store the result on the dict.

# Initial data
train_encoded_disease = encoding_disease(train_data, w2idx)
train_padded_encoded_disease = pad_sequences(train_encoded_disease)
train_data['encoded_disease'] = train_padded_encoded_disease

# Feedback data, train split
fb_train_encoded_disease = encoding_disease(fb_data_train, w2idx)
fb_train_padded_encoded_disease = pad_sequences(fb_train_encoded_disease)
fb_data_train['encoded_disease'] = fb_train_padded_encoded_disease

# Feedback data, test split
fb_test_encoded_disease = encoding_disease(fb_data_test, w2idx)
fb_test_padded_encoded_disease = pad_sequences(fb_test_encoded_disease)
fb_data_test['encoded_disease'] = fb_test_padded_encoded_disease

  return np.array(idx_total)


In [15]:
# Reshape each feature into a single-column 2-D array, the layout the
# MinMaxScaler calls in the next cell are given.
train_data['age'] = train_data['age'].reshape(-1, 1)
train_data['weight'] = train_data['weight'].reshape(-1, 1)


In [16]:
# One independent min-max scaler per feature, fitted on the initial data set.
scaler_age = MinMaxScaler().fit(train_data['age'])
scaler_weight = MinMaxScaler().fit(train_data['weight'])

# Replace the raw columns with their scaled versions.
train_data['age'] = scaler_age.transform(train_data['age'])
train_data['weight'] = scaler_weight.transform(train_data['weight'])

In [17]:
class InitialDataset(Dataset):
    """Dataset view over the module-level ``train_data`` / ``lab_ratings``.

    Each item is a dict of tensors with keys ``age``, ``weight``, ``disease``,
    ``icd_codes`` and ``target``. ``disease`` is converted to ``torch.long``;
    every other entry is built with ``dtype=float``.
    """

    def __init__(self):
        self.data, self.ratings = train_data, lab_ratings

    def __len__(self):
        # The number of samples is taken from the 'hadm_id' entry.
        return len(self.data['hadm_id'])

    def __getitem__(self, index):
        record = self.data

        def as_float(value):
            return torch.tensor(value, dtype=float)

        return {
            'age': as_float(record['age'][index]),
            'weight': as_float(record['weight'][index]),
            'disease': torch.tensor(record['encoded_disease'][index], dtype=torch.long),
            'icd_codes': as_float(record['codes'][index]),
            'target': as_float(self.ratings[index]),
        }

In [18]:
# Size of one client's shard: the feedback training labels divided evenly by
# `user` (floor division, so any remainder is left out of every shard).
num_traindata = len(fb_labels_train) // user
# Zero-based position of this client's shard; it selects the slice
# [num_traindata * client_order : num_traindata * (client_order + 1)].
client_order = 0

In [19]:
# Reshape each feature into a single-column 2-D array for MinMaxScaler.
fb_data_train['age'] = fb_data_train['age'].reshape((-1, 1))
fb_data_train['weight'] = fb_data_train['weight'].reshape((-1, 1))

# The scalers are fitted on the full feedback training set; only afterwards is
# this client's shard cut out.
scaler_age = MinMaxScaler()
scaler_weight = MinMaxScaler()
client_slice = slice(num_traindata * client_order, num_traindata * (client_order + 1))
fb_data_train['age'] = scaler_age.fit_transform(fb_data_train['age'])[client_slice]
fb_data_train['weight'] = scaler_weight.fit_transform(fb_data_train['weight'])[client_slice]
fb_data_train['codes'] = fb_data_train['codes'][client_slice]
# 'encoded_disease' must be cut with the same slice. FeedbackDataset indexes
# every entry with the same position, so for client_order > 0 an unsliced
# array would pair each sample with another record's disease tokens.
fb_data_train['encoded_disease'] = fb_data_train['encoded_disease'][client_slice]
fb_labels_train = fb_labels_train[client_slice]

In [20]:
# Reshape each feature into a single-column 2-D array for MinMaxScaler.
fb_data_test['age'] = fb_data_test['age'].reshape((-1, 1))
fb_data_test['weight'] = fb_data_test['weight'].reshape((-1, 1))

# Scale the test split with the scalers already fitted on the feedback
# *training* data in the previous cell (scaler_age / scaler_weight). Fitting
# fresh scalers on the test split would map train and test features with
# different min/max values and leak test-set statistics into preprocessing.
# Test values outside the training range fall outside [0, 1], which is expected.
fb_data_test['age'] = scaler_age.transform(fb_data_test['age'])
fb_data_test['weight'] = scaler_weight.transform(fb_data_test['weight'])

In [21]:
class FeedbackDataset(Dataset):
    """Dataset view over the module-level feedback data.

    ``train=True`` selects ``fb_data_train`` / ``fb_labels_train``; otherwise
    ``fb_data_test`` / ``fb_labels_test`` are used. Each item is a dict of
    tensors with keys ``age``, ``weight``, ``disease``, ``icd_codes`` and
    ``target``. ``disease`` is converted to ``torch.long``; every other entry
    is built with ``dtype=float``.
    """

    def __init__(self, train=False):
        self.data, self.ratings = (
            (fb_data_train, fb_labels_train) if train
            else (fb_data_test, fb_labels_test)
        )

    def __len__(self):
        # The number of samples is taken from the 'age' entry.
        return len(self.data['age'])

    def __getitem__(self, index):
        record = self.data

        def as_float(value):
            return torch.tensor(value, dtype=float)

        return {
            'age': as_float(record['age'][index]),
            'weight': as_float(record['weight'][index]),
            'disease': torch.tensor(record['encoded_disease'][index], dtype=torch.long),
            'icd_codes': as_float(record['codes'][index]),
            'target': as_float(self.ratings[index]),
        }

In [22]:
# Datasets first ...
initial_dataset = InitialDataset()
fb_train_dataset = FeedbackDataset(train=True)
fb_test_dataset = FeedbackDataset(train=False)

# ... then their loaders. Both training loaders reshuffle on every pass; the
# test loader keeps the stored order.
train_loader = DataLoader(initial_dataset, shuffle=True, batch_size=64)
fb_train_loader = DataLoader(fb_train_dataset, shuffle=True, batch_size=64)
fb_test_loader = DataLoader(fb_test_dataset, batch_size=64)

In [23]:
# Number of batches per pass over the initial training loader; this value is
# sent to the server during the first handshake.
total_batch = len(train_loader)
print(total_batch)

181


In [24]:
class RecSysClient(nn.Module):
    """Client-side part of the network.

    The disease token embeddings are averaged over the sequence axis,
    concatenated with the ICD codes, age and weight, and passed through two
    ReLU-activated linear layers. The 512-wide activation is returned.
    """

    def __init__(self):
        super().__init__()
        # The embedding table is initialised from the module-level
        # embedding_matrix (one row per vocabulary index).
        pretrained = torch.Tensor(embedding_matrix)
        vocab_size, embed_dim = pretrained.shape
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.embed.weight.data = pretrained.clone()

        # 152 is the width of the concatenated feature vector built in forward().
        self.fc1 = nn.Linear(152, 1024)
        self.fc2 = nn.Linear(1024, 512)

    def forward(self, age, weight, icd_codes, disease):
        # Mean over dim 1 of the embedded disease indices.
        pooled = self.embed(disease).mean(dim=1)
        features = torch.cat((pooled, icd_codes, age, weight), dim=1).float()
        hidden = F.relu(self.fc1(features))
        return F.relu(self.fc2(hidden))

In [25]:
# Client-side model instance; its weights are overwritten by the state dict
# received from the server at the start of every epoch.
client = RecSysClient()

In [26]:
# Placeholder only: `epoch` is replaced by the value received from the server
# before each training phase.
epoch = 1
# NOTE(review): criterion is not used by any cell visible here; presumably the
# loss is computed on the server side. Confirm before removing.
criterion = nn.MSELoss()
lr = 0.001
# One Adam optimiser over all client parameters, shared by both training phases.
optimizer = Adam(client.parameters(), lr=lr)

In [27]:
def send_msg(sock, msg):
    """Pickle ``msg`` and send it as one length-prefixed frame on ``sock``.

    The frame is a 4-byte big-endian unsigned length followed by the pickled
    payload, which is the layout ``recv_msg`` reads back.
    """
    payload = pickle.dumps(msg)
    header = struct.pack('>I', len(payload))
    sock.sendall(header + payload)

def recv_msg(sock):
    """Receive one length-prefixed pickled message from ``sock``.

    Reads a 4-byte big-endian length, then that many payload bytes, and
    unpickles the payload.

    Returns:
        The unpickled object, or ``None`` if the peer closed the connection
        before a complete message arrived (at the header or mid-payload).
    """
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    # read the message data
    payload = recvall(sock, msglen)
    if payload is None:
        # Connection dropped mid-message: report EOF like a missing header,
        # instead of letting pickle.loads(None) raise a TypeError.
        return None
    # NOTE(security): pickle.loads can execute arbitrary code supplied by the
    # peer; only use this with a trusted server on a trusted network.
    return pickle.loads(payload)

def recvall(sock, n):
    """Read exactly ``n`` bytes from ``sock``.

    Returns:
        ``bytes`` of length ``n``, or ``None`` if EOF is hit first.
    """
    # bytearray.extend avoids re-copying the buffer on every received chunk.
    data = bytearray()
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data.extend(packet)
    return bytes(data)

In [28]:
# Address and port of the server that the client socket connects to.
# NOTE(review): hard-coded endpoint; change it here when the server moves.
host = '10.10.7.64'
port = 10080

In [29]:
# A single socket is used for both training phases. It is not closed in any
# cell visible here.
s = socket.socket()
s.connect((host, port))

In [30]:
# Handshake for the first phase: receive the epoch count from the server, then
# reply with the number of batches per epoch.
# recv_msg returns None if the server has closed the connection, in which case
# range(epoch) in the training loop fails.
epoch = recv_msg(s)   # get epoch
msg = total_batch
send_msg(s, msg) 

In [31]:
# Split training on the initial data set. The client runs the forward pass,
# ships its activations and labels to the server, and back-propagates the
# gradient the server returns.
for e in range(epoch):
    # Every epoch starts from the client weights pushed by the server.
    client_weights = recv_msg(s)
    if client_weights is None:
        raise ConnectionError("server closed the connection before sending client weights")
    client.load_state_dict(client_weights)
    # Parameters are optimised below, so the module belongs in training mode.
    # eval() would silently disable training-only layers (dropout, batch norm)
    # if any are ever added to the model.
    client.train()
    for data in tqdm(train_loader, ncols=100, desc='Epoch '+str(e+1)):
        age = data['age'].to(device)
        weight = data['weight'].to(device)
        icd_codes = data['icd_codes'].to(device)
        disease = data['disease'].to(device)
        target = data['target'].to(device)

        optimizer.zero_grad()
        output = client(age, weight, icd_codes, disease)
        # Send a detached copy so the client's autograd graph is not part of
        # the pickled message.
        client_output = output.clone().detach().requires_grad_(True)
        msg = {
            'client_output': client_output,
            'label': target
        }
        send_msg(s, msg)
        # The reply is used as d(loss)/d(output) to continue back-propagation
        # through the client layers.
        client_grad = recv_msg(s)
        if client_grad is None:
            raise ConnectionError("server closed the connection before sending gradients")
        output.backward(client_grad)
        optimizer.step()
    # Return the updated client weights to the server after each epoch.
    send_msg(s, client.state_dict())
    # Short pause before the next epoch (presumably gives the server time to
    # finish its side; confirm with the server code).
    time.sleep(0.5)

Epoch 1: 100%|████████████████████████████████████████████████████| 181/181 [00:07<00:00, 23.65it/s]
Epoch 2: 100%|████████████████████████████████████████████████████| 181/181 [00:07<00:00, 25.18it/s]
Epoch 3: 100%|████████████████████████████████████████████████████| 181/181 [00:07<00:00, 24.89it/s]
Epoch 4: 100%|████████████████████████████████████████████████████| 181/181 [00:07<00:00, 25.02it/s]
Epoch 5: 100%|████████████████████████████████████████████████████| 181/181 [00:07<00:00, 24.36it/s]
Epoch 6: 100%|████████████████████████████████████████████████████| 181/181 [00:07<00:00, 24.20it/s]
Epoch 7: 100%|████████████████████████████████████████████████████| 181/181 [00:07<00:00, 25.66it/s]
Epoch 8: 100%|████████████████████████████████████████████████████| 181/181 [00:07<00:00, 24.71it/s]
Epoch 9: 100%|████████████████████████████████████████████████████| 181/181 [00:07<00:00, 25.54it/s]
Epoch 10: 100%|███████████████████████████████████████████████████| 181/181 [00:07<00:00, 2

In [32]:
# s = socket.socket()
# s.connect((host, port))

In [33]:
# Handshake for the feedback phase: same exchange as before, with the batch
# count taken from the feedback training loader.
total_batch = len(fb_train_loader)
epoch = recv_msg(s)   # get epoch
msg = total_batch
send_msg(s, msg) 

In [34]:
# Split training on this client's feedback shard. The client runs the forward
# pass, ships its activations and labels to the server, and back-propagates
# the gradient the server returns.
for e in range(epoch):
    # Every epoch starts from the client weights pushed by the server.
    client_weights = recv_msg(s)
    if client_weights is None:
        raise ConnectionError("server closed the connection before sending client weights")
    client.load_state_dict(client_weights)
    # Parameters are optimised below, so the module belongs in training mode.
    # eval() would silently disable training-only layers (dropout, batch norm)
    # if any are ever added to the model.
    client.train()
    for data in tqdm(fb_train_loader, ncols=100, desc='Epoch '+str(e+1)):
        age = data['age'].to(device)
        weight = data['weight'].to(device)
        icd_codes = data['icd_codes'].to(device)
        disease = data['disease'].to(device)
        target = data['target'].to(device)

        optimizer.zero_grad()
        output = client(age, weight, icd_codes, disease)
        # Send a detached copy so the client's autograd graph is not part of
        # the pickled message.
        client_output = output.clone().detach().requires_grad_(True)
        msg = {
            'client_output': client_output,
            'label': target
        }
        send_msg(s, msg)
        # The reply is used as d(loss)/d(output) to continue back-propagation
        # through the client layers.
        client_grad = recv_msg(s)
        if client_grad is None:
            raise ConnectionError("server closed the connection before sending gradients")
        output.backward(client_grad)
        optimizer.step()
    # Return the updated client weights to the server after each epoch.
    send_msg(s, client.state_dict())
    # Short pause before the next epoch (presumably gives the server time to
    # finish its side; confirm with the server code).
    time.sleep(0.5)

Epoch 1: 100%|██████████████████████████████████████████████████████| 73/73 [00:02<00:00, 26.09it/s]
Epoch 2: 100%|██████████████████████████████████████████████████████| 73/73 [00:02<00:00, 25.54it/s]
Epoch 3: 100%|██████████████████████████████████████████████████████| 73/73 [00:02<00:00, 25.94it/s]
Epoch 4: 100%|██████████████████████████████████████████████████████| 73/73 [00:02<00:00, 26.02it/s]
Epoch 5: 100%|██████████████████████████████████████████████████████| 73/73 [00:02<00:00, 25.10it/s]
Epoch 6: 100%|██████████████████████████████████████████████████████| 73/73 [00:02<00:00, 25.55it/s]
Epoch 7: 100%|██████████████████████████████████████████████████████| 73/73 [00:02<00:00, 25.71it/s]
Epoch 8: 100%|██████████████████████████████████████████████████████| 73/73 [00:02<00:00, 24.43it/s]
Epoch 9: 100%|██████████████████████████████████████████████████████| 73/73 [00:02<00:00, 25.57it/s]
Epoch 10: 100%|█████████████████████████████████████████████████████| 73/73 [00:02<00:00, 2