In [None]:
%reset

In [None]:
import os
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader

In [None]:
# Automatically reload edited project modules (e.g. causal_inference.*) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Root of the `causal_inference` repository. Override with the
# CAUSAL_INFERENCE_ROOT environment variable instead of editing this cell;
# the default keeps the original location.
# NOTE(review): the chdir is presumably what makes `causal_inference`
# importable in the next lines; a sys.path entry or an editable install
# would be less fragile.
PROJECT_ROOT = os.environ.get('CAUSAL_INFERENCE_ROOT', '/home/adam/adam/causal_inference')
os.chdir(PROJECT_ROOT)

from causal_inference.model.cfr import UseCase
from causal_inference.model.metrics import MMDLoss

In [None]:
# Data directory; override with CAUSAL_INFERENCE_DATA instead of editing the
# absolute path. UseCase receives a bare file name, so the CSV is presumably
# resolved against this working directory.
DATA_DIR = os.environ.get('CAUSAL_INFERENCE_DATA', '/home/adam/adam/data/19012021/')
os.chdir(DATA_DIR)

# One seed for the train/test split and for torch: it also fixes the batch
# shuffling below and the random weight init of networks built in later cells.
SEED = 1234
torch.manual_seed(SEED)

# load the dataset; the two column names are presumably the outcome and the
# treatment indicator (UseCase's signature is not visible here)
dataset = UseCase('data_guerin_rct.csv',
                  'pf_ratio_2h_8h_outcome',
                  'treated',
                  seed=SEED)

# calculate split
train, test = dataset.get_splits()

# prepare data loaders: only the training set is shuffled
batch_size = 64
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test, batch_size=batch_size, shuffle=False)



In [None]:
# Architecture sizes: input width and hidden/representation layer widths.
# NOTE(review): the network classes defined below do not read these names;
# they hard-code their own widths.
n_of_features, hidden_layer_1, hidden_layer_2, representation_layer = 27, 5, 5, 1


In [None]:
class RepresentationNetwork(nn.Module):
    """Representation network phi for the CFR model.

    Input rows are ``[t, x_1, ..., x_n]``: the treatment indicator in column 0,
    followed by ``n_covariates`` covariates. The covariates go through a
    three-layer MLP with ELU on the two hidden layers and a linear output of
    the same width as the covariates. ``t`` is passed through unchanged and
    re-attached as column 0 of the output.

    Args:
        n_covariates: number of covariate columns after ``t`` (default 26).
        width_factor: hidden width is ``width_factor * n_covariates`` (default 10).
    """

    def __init__(self, n_covariates=26, width_factor=10):
        super(RepresentationNetwork, self).__init__()
        hidden = width_factor * n_covariates
        self.rep1 = nn.Linear(n_covariates, hidden)
        self.rep2 = nn.Linear(hidden, hidden)
        self.rep3 = nn.Linear(hidden, n_covariates)

    def forward(self, x):
        """Map ``(batch, 1 + n_covariates)`` to the same shape, ``[t, phi(covariates)]``."""
        # Slice with :1 so t stays a (batch, 1) column, ready to concatenate back.
        t = x[:, :1]
        h = nn.functional.elu(self.rep1(x[:, 1:]))
        h = nn.functional.elu(self.rep2(h))
        return torch.cat((t, self.rep3(h)), dim=1)


representation_model = RepresentationNetwork()

class FactualModel(nn.Module):
    """Outcome head: an MLP regressing the outcome on an input vector.

    Two ELU hidden layers of width ``width_factor * n_inputs`` feed a single
    linear output unit, so the map is ``(batch, n_inputs) -> (batch, 1)``.

    Args:
        n_inputs: input width (default 26).
        width_factor: hidden-width multiplier (default 10).
    """

    def __init__(self, n_inputs=26, width_factor=10):
        super(FactualModel, self).__init__()
        hidden = width_factor * n_inputs
        self.factual1 = nn.Linear(n_inputs, hidden)
        self.factual2 = nn.Linear(hidden, hidden)
        self.factual3 = nn.Linear(hidden, 1)

    def forward(self, x):
        x = nn.functional.elu(self.factual1(x))
        x = nn.functional.elu(self.factual2(x))
        return self.factual3(x)


# One independent head per treatment arm.
control_model = FactualModel()
treated_model = FactualModel()

In [None]:
class ResNet(nn.Module):
    """Baseline regressor: two residual ReLU blocks with dropout, then a linear output.

    Maps ``(batch, n_features) -> (batch, 1)``.

    Args:
        n_features: input width (default 27).
        width_factor: hidden width is ``width_factor * n_features`` (default 4).
        dropout: dropout probability after each residual block (default 0.1).
    """

    def __init__(self, n_features=27, width_factor=4, dropout=0.1):
        super().__init__()
        hidden = width_factor * n_features
        self.l1 = nn.Linear(n_features, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        self.do1 = nn.Dropout(dropout)
        self.l3 = nn.Linear(hidden, hidden)
        self.l4 = nn.Linear(hidden, hidden)
        self.do2 = nn.Dropout(dropout)
        self.l5 = nn.Linear(hidden, 1)

    def forward(self, x):
        h1 = nn.functional.relu(self.l1(x))
        h2 = nn.functional.relu(self.l2(h1))
        do1 = self.do1(h1 + h2)  # residual block 1
        h3 = nn.functional.relu(self.l3(do1))
        # Block 2 uses its own layer l4 (previously l2 was reused, so l4 never trained).
        h4 = nn.functional.relu(self.l4(h3))
        do2 = self.do2(h3 + h4)  # residual block 2
        # Project to one output with l5 (previously l3, which returned `hidden` columns).
        return self.l5(do2)


model = ResNet()

In [None]:
# Parameter lists of the CFR networks. They are materialised because
# `.parameters()` returns a one-shot generator that the first consumer exhausts.
representation_params = list(representation_model.parameters())
control_params = list(control_model.parameters())
treated_params = list(treated_model.parameters())

lr = 0.005

# `params` was never defined (NameError). The next training cell
# back-propagates through `model` and then calls `optimizer.step()`, so the
# optimizer must own the ResNet baseline's parameters.
params = list(model.parameters())
optimizer = optim.Adam(params, lr=lr)

In [None]:
# Regression loss for the ResNet baseline, used by the next training cell
# (which failed with NameError while this line was commented out).
# `reduction` is the supported keyword; `reduce` is a deprecated boolean flag.
loss = nn.MSELoss(reduction='mean')

In [None]:
nb_epochs = 20

# Train the ResNet baseline on the factual outcome. Column 0 of x is t, which
# this model treats as one more input feature.
for epoch in range(nb_epochs):
    # ---- training pass ----
    losses = list()
    model.train()
    for batch in train_loader:
        x, y = batch
        x, y = x.float(), y.float()

        # temporary: zero-fill missing covariates (NaN is the only value != itself)
        x[x != x] = 0

        # 1. forward. The model is expected to return one value per row; reshape
        #    it to y's shape so MSELoss compares element-wise. Without this, a
        #    (batch, 1) prediction against a 1-D y would silently broadcast to
        #    (batch, batch). NOTE(review): y's shape comes from UseCase.
        prediction = model(x).reshape(y.shape)

        # 2. objective function
        objective = loss(prediction, y)

        # 3. clear stale gradients, 4. back-propagate, 5. update weights
        model.zero_grad()
        objective.backward()
        optimizer.step()

        losses.append(objective.item())

    print(f'Epoch {epoch + 1}, train loss: {torch.tensor(losses).mean():.2f}')

    # ---- validation pass: dropout off, no autograd graph built ----
    losses = list()
    model.eval()
    with torch.no_grad():
        for batch in test_loader:
            x, y = batch
            x, y = x.float(), y.float()

            # temporary: zero-fill missing covariates
            x[x != x] = 0

            prediction = model(x).reshape(y.shape)
            objective = loss(prediction, y)
            losses.append(objective.item())

    print(f'Epoch {epoch + 1}, valid loss: {torch.tensor(losses).mean():.2f}')

In [None]:
from causal_inference.model.metrics import mmd_loss


In [None]:
# MMD-based loss from the project's metrics module, with the 'multiscale' kernel.
# NOTE(review): presumably meant as CFR's balancing term between control and
# treated representations; confirm MMDLoss's call signature (two samples of
# equal or arbitrary size?) in causal_inference.model.metrics.
representation_loss = MMDLoss(kernel='multiscale')
# Regression loss for the per-arm outcome heads. `reduction` is the supported
# keyword. `reduce` is a deprecated boolean flag: passing the string 'mean' went
# through the legacy code path and raised a deprecation warning.
factual_loss = nn.MSELoss(reduction='mean')

In [None]:
nb_epochs = 1
# Weight of the representation-balancing (MMD) term in the CFR objective.
# TODO: tune; 1.0 is an untuned placeholder.
alpha = 1.0

cfr_modules = (representation_model, control_model, treated_model)
# A single optimizer over all three CFR networks. It replaces the unfinished
# manual update, which only printed tensors and never changed a weight.
cfr_optimizer = optim.Adam(
    [p for module in cfr_modules for p in module.parameters()], lr=lr
)


def _cfr_losses(x, y):
    """Compute the factual loss and the representation-imbalance term for one batch.

    ``x`` carries the treatment indicator t in column 0, with covariates after it.
    Both outcome heads are fed the learned representation phi(x), not the raw
    covariates, so the factual loss also trains ``representation_model``.

    Returns ``(factual, imbalance)``:
    - ``factual`` is the sum of per-arm MSEs over the arms present in the
      batch. It is None when no row has t == 0 or t == 1.
    - ``imbalance`` is the MMD between control and treated representations,
      or None unless both arms are present.
    """
    t = x[:, 0]
    phi = representation_model(x)[:, 1:]  # column 0 of the output is t passed through
    control_mask, treated_mask = t == 0, t == 1

    factual = None
    for mask, head in ((control_mask, control_model), (treated_mask, treated_model)):
        if mask.any():
            y_arm = y[mask]
            # reshape the (n_arm, 1) head output to y's shape to avoid MSE broadcasting
            arm_loss = factual_loss(head(phi[mask]).reshape(y_arm.shape), y_arm)
            factual = arm_loss if factual is None else factual + arm_loss

    imbalance = None
    if control_mask.any() and treated_mask.any():
        # NOTE(review): assumes MMDLoss(a, b) computes a two-sample MMD between
        # the control and treated representations and accepts groups of
        # different sizes. Confirm against causal_inference.model.metrics.
        imbalance = representation_loss(phi[control_mask], phi[treated_mask])
    return factual, imbalance


for epoch in range(nb_epochs):
    # ---- training pass ----
    for module in cfr_modules:
        module.train()
    losses = list()
    for batch in train_loader:
        x, y = batch
        x, y = x.float(), y.float()

        # temporary: zero-fill missing covariates (NaN is the only value != itself)
        x[x != x] = 0

        factual, imbalance = _cfr_losses(x, y)
        if factual is None:
            continue  # no row with t in {0, 1}: nothing to fit
        objective = factual if imbalance is None else factual + alpha * imbalance

        cfr_optimizer.zero_grad()
        objective.backward()
        cfr_optimizer.step()

        losses.append(objective.item())

    print(f'Epoch {epoch + 1}, train loss: {torch.tensor(losses).mean():.2f}')

    # ---- validation pass: factual error of the CFR networks ----
    for module in cfr_modules:
        module.eval()
    losses = list()
    with torch.no_grad():
        for batch in test_loader:
            x, y = batch
            x, y = x.float(), y.float()

            # temporary: zero-fill missing covariates
            x[x != x] = 0

            factual, _ = _cfr_losses(x, y)
            if factual is not None:
                losses.append(factual.item())

    print(f'Epoch {epoch + 1}, valid loss: {torch.tensor(losses).mean():.2f}')

In [None]:
# Display the loss tensor left behind by whichever training cell ran last.
objective

In [None]:
# Manual plain-SGD step on the ResNet baseline (an alternative to optimizer.step()).
# - Uses `lr`; the name `learning_rate` was never defined.
# - Skips parameters whose `.grad` is still None (no backward pass reached them).
# - Updates in place under no_grad rather than through the discouraged `.data`.
with torch.no_grad():
    for param in model.parameters():
        if param.grad is not None:
            param -= lr * param.grad
