## COVARIATE SHIFT PROBLEM

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


def generate_data(n, mean, p=2):
    """
    mean: shift of the feature distribution
    p: exponent used for label rule (same for train & test)
    """
    X = torch.randn(n, 2) + torch.tensor(mean, dtype=torch.float)
    y = (X[:, 0]**p + X[:, 1]**p > 0).long()
    return X, y


X_train, y_train = generate_data(5000, mean=[-2, -2], p=2)

X_test, y_test = generate_data(2000, mean=[+2, +2], p=2)


model = nn.Linear(2, 2)  # 2 inputs → 2 classes
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.05)


for epoch in range(200):
    optimizer.zero_grad()
    out = model(X_train)
    loss = criterion(out, y_train)
    loss.backward()
    optimizer.step()


def accuracy(model, X, y):
    with torch.no_grad():
        preds = model(X).argmax(dim=1)
    return (preds == y).float().mean().item()

train_acc = accuracy(model, X_train, y_train)
test_acc = accuracy(model, X_test, y_test)

print("Train Accuracy:", train_acc)
print("Test Accuracy:", test_acc)

# we can tell that training and testing accuracy doesn't match causing covariance shift problem.
# what can be done ??????



Train Accuracy: 1.0
Test Accuracy: 0.04899999871850014


SOLUTION 


1) We need to build a classifier (Logistic classifier model)
2) It tries to segregate the test and train datapoints 
3) we assign high weightage to the datapoint which is similar to test data rather than the train datapoint 
4) Finally based on this we update the weights of the model via training and fix this issue 

In [25]:
mean_train = [-2.0, -2.0]
cov_train  = [[1.0, 0.0],
              [0.0, 1.0]]

mean_test = [2.0, 2.0]
cov_test  = [[1.0, 0.0],
             [0.0, 1.0]]

In [32]:

class Logistic_Regression(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)  # 2 inputs →
        
    def forward(self, x):
        return self.linear(x)
    
    
model_unweighted = Logistic_Regression()
model_weighted   = Logistic_Regression()

criterion = nn.CrossEntropyLoss(reduction="none")   # important: no reduction
optimizer_unw = optim.Adam(model_unweighted.parameters(), lr=0.05)
optimizer_w   = optim.Adam(model_weighted.parameters(), lr=0.05)
    


In [33]:
train_dist = torch.distributions.MultivariateNormal(
    torch.tensor(mean_train, dtype=torch.float),
    torch.tensor(cov_train, dtype=torch.float)
)
test_dist = torch.distributions.MultivariateNormal(
    torch.tensor(mean_test, dtype=torch.float),
    torch.tensor(cov_test, dtype=torch.float)
)


In [None]:
w = torch.exp(test_dist.log_prob(X_train) - train_dist.log_prob(X_train))

w = w / w.mean() 




In [40]:
def train(model, optimizer, weighted=False):
    for epoch in range(200):
        optimizer.zero_grad()
        
        pred = model(X_train)
        loss = criterion(pred, y_train)
        
        if weighted:
            loss = (w * loss).mean()   # apply importance weights
        else:
            loss = loss.mean()
        
        loss.backward()
        optimizer.step()


train(model_unweighted, optimizer_unw, weighted=False)
train(model_weighted, optimizer_w, weighted=True)

In [None]:
def accuracy(model, X, y):
    with torch.no_grad():
        preds = model(X).argmax(dim=1)
    return (preds == y).float().mean().item()


print("Train Accuracy (no weights):", accuracy(model_unweighted, X_train, y_train))
print("Test Accuracy  (no weights):", accuracy(model_unweighted, X_test, y_test))
print()
print("Train Accuracy (weighted):", accuracy(model_weighted, X_train, y_train))
print("Test Accuracy  (weighted):", accuracy(model_weighted, X_test, y_test))



# Model fixes the training data by adding weightage to datapoints similar to test data and train data is updated and tested 
# So Covariate shift problem is solved.

Train Accuracy (no weights): 1.0
Test Accuracy  (no weights): 0.0794999971985817

Train Accuracy (weighted): 0.510200023651123
Test Accuracy  (weighted): 1.0


### LABEL SHIFT CORRECTION 

In [50]:
import torch
import numpy as np

def generate_data_label_shift(n, mean):
    X = torch.randn(n, 2) + torch.tensor(mean)
    logits = X[:,0] + X[:,1]
    y = (logits > 0).long()
    return X, y

# Train distribution
X_train, y_train = generate_data_label_shift(5000, mean=[0,0])

# Test distribution with changed label probability
X_test, y_test = generate_data_label_shift(5000, mean=[0,0])

# Introduce label shift artificially
# Flip 70% of negatives to positives
mask = (y_test == 0)
flip = torch.rand(mask.sum()) < 0.7
y_test_shifted = y_test.clone()
y_test_shifted[mask] = flip.long()

# Validation distribution (clean, no shift)
X_val, y_val = generate_data_label_shift(2000, mean=[0,0])


In [None]:
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(
    nn.Linear(2, 2)
)

loss_fn = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(200):
    opt.zero_grad()
    y_pred = model(X_train)
    loss = loss_fn(y_pred, y_train)
    loss.backward()
    opt.step()


with torch.no_grad():
    pred_train = model(X_train).argmax(1)
    pred_test = model(X_test).argmax(1)

baseline_train_acc = (pred_train == y_train).float().mean().item()
baseline_test_acc = (pred_test == y_test_shifted).float().mean().item()



print("Baseline Train Accuracy:", baseline_train_acc)
print("Baseline Test Accuracy:", baseline_test_acc)

Baseline Train Accuracy: 0.9660000205039978
Baseline Test Accuracy: 0.6489999890327454


In [55]:
X_val, y_val = generate_data_label_shift(2000, mean=[0,0])

with torch.no_grad():
    val_pred = model(X_val).argmax(dim=1)

# Confusion matrix C[i,j] = P(pred=i | true=j)
C = torch.zeros(2,2)

for i in range(2):
    for j in range(2):
        C[i,j] = ((val_pred == i) & (y_val == j)).sum()

# Normalize columns so each column sums to 1
C = C / C.sum(dim=0, keepdim=True)


In [56]:
with torch.no_grad():
    pred_test = model(X_test).argmax(dim=1)

q = torch.stack([(pred_test==0).float().mean(),
                 (pred_test==1).float().mean()])


In [57]:
with torch.no_grad():
    pred_test = model(X_test).argmax(dim=1)

q = torch.stack([(pred_test==0).float().mean(),
                 (pred_test==1).float().mean()])


In [58]:
p_test = torch.linalg.solve(C, q)
p_test = torch.clamp(p_test, min=1e-6)
p_test = p_test / p_test.sum()


p_train = torch.tensor([
    (y_train==0).float().mean(),
    (y_train==1).float().mean()
])

w = p_test / p_train
w


tensor([0.9878, 1.0124])

In [59]:
model_shift = nn.Sequential(nn.Linear(2, 2))
opt = optim.Adam(model_shift.parameters(), lr=0.01)
loss_fn_reduced = nn.CrossEntropyLoss(reduction="none")

for epoch in range(200):
    opt.zero_grad()
    logits = model_shift(X_train)
    sample_weights = w[y_train]
    loss = (sample_weights * loss_fn_reduced(logits, y_train)).mean()
    loss.backward()
    opt.step()

In [60]:
with torch.no_grad():
    pred_test_corrected = model_shift(X_test).argmax(1)

corrected_test_acc = (pred_test_corrected == y_test_shifted).float().mean().item()

In [61]:
print("==============================================")
print(" BASELINE (No Label-Shift Correction)")
print(" Train Accuracy          :", round(baseline_train_acc, 4))
print(" Test Accuracy (shifted) :", round(baseline_test_acc, 4))
print("----------------------------------------------")
print(" ESTIMATED test label distribution p_test:", p_test)
print(" Train label distribution p_train        :", p_train)
print(" Importance weights                      :", w)
print("----------------------------------------------")
print(" AFTER LABEL-SHIFT CORRECTION (BBSE)")
print(" Corrected Test Accuracy :", round(corrected_test_acc, 4))
print("==============================================")

 BASELINE (No Label-Shift Correction)
 Train Accuracy          : 0.966
 Test Accuracy (shifted) : 0.649
----------------------------------------------
 ESTIMATED test label distribution p_test: tensor([0.4963, 0.5037])
 Train label distribution p_train        : tensor([0.5024, 0.4976])
 Importance weights                      : tensor([0.9878, 1.0124])
----------------------------------------------
 AFTER LABEL-SHIFT CORRECTION (BBSE)
 Corrected Test Accuracy : 0.6534


In [None]:
# There is no much difference in test accuracy after label shift correction because the model was already performing well on the shifted data.