In [38]:
import torch
import torch.nn as nn

class CORAL(nn.Module):
    def forward(self, source_features, target_features):
        d = source_features.size(1)  # Feature dimension

        source_mean = torch.mean(source_features, dim=0)
        target_mean = torch.mean(target_features, dim=0)
        source_cov = (source_features - source_mean).T @ (source_features - source_mean) / (source_features.size(0) - 1)
        target_cov = (target_features - target_mean).T @ (target_features - target_mean) / (target_features.size(0) - 1)

        # Frobenius norm for covariance matrices difference
        coral_loss = torch.norm(source_cov - target_cov, p='fro') / (4*d*d)
        return coral_loss


In [39]:
from setup import setup_src_path
print(setup_src_path())
import data.processed as processed
import config.config as config
import utils.setup as setup
import utils.functions as fn
from importlib import reload



['/home/guest/Desktop/projects/intial-experments/domain_adaptation_project/notebooks', '/usr/lib/python38.zip', '/usr/lib/python3.8', '/usr/lib/python3.8/lib-dynload', '', '/home/guest/.cache/pypoetry/virtualenvs/intial-experments-_CPDD38x-py3.8/lib/python3.8/site-packages', '/home/guest/Desktop/projects/intial-experments/domain_adaptation_project/modules', '/tmp/tmptt5ljm60']


In [40]:
from datasets import load_from_disk

source_data=load_from_disk(f"{config.Config.DATASETS_SAVE_PATH}/source_data")
source_data_eval=load_from_disk(f"{config.Config.DATASETS_SAVE_PATH}/source_data_eval")
target_data=load_from_disk(f"{config.Config.DATASETS_SAVE_PATH}/target_data")
test_target_data=load_from_disk(f"{config.Config.DATASETS_SAVE_PATH}/test_target_data")
unsupervised_target=load_from_disk(f"{config.Config.DATASETS_SAVE_PATH}/unsupervised_target")


In [46]:
from torch.utils.data import DataLoader
source_data_loader = DataLoader(source_data, batch_size=64, shuffle=True)
target_data_loader = DataLoader(target_data, batch_size=64)
target_test_loader = DataLoader(test_target_data, batch_size=64)

In [42]:
from adapters import AutoAdapterModel
from tqdm import tqdm

model = AutoAdapterModel.from_pretrained(config.Config.MODEL_NAME)
reload(fn)
fn.print_trainable_parameters(model)


trainable params: 66985530 || all params: 66985530 || trainable%: 100.0


In [47]:
model.delete_adapter(adapter_name)


In [48]:
from transformers import AdamW
adapter_name= "domain-adapter-coral"
model.add_adapter(adapter_name,config="seq_bn")
model.set_active_adapters(adapter_name)
model.train_adapter(adapter_name)

optimizer = AdamW(model.parameters(), lr=1e-4)
coral = CORAL()

In [49]:
fn.print_trainable_parameters(model)


trainable params: 1039392 || all params: 67432794 || trainable%: 1.5413746611181498


In [50]:
def train_epoch(model, source_loader, target_loader, optimizer, coral, device):
    model.train()
    total_coral_loss = 0
    
    # Assuming the same number of batches in source and target loaders
    for source_batch, target_batch in tqdm(zip(source_loader, target_loader)):
        optimizer.zero_grad()
        # Prepare input and move to device
        source_input_ids, source_attention_mask, source_labels = (source_batch["input_ids"].to(device), 
                                                                       source_batch["attention_mask"].to(device), 
                                                                       source_batch["labels"].to(device))
        target_input_ids, target_attention_mask = (target_batch["input_ids"].to(device), 
                                                       target_batch["attention_mask"].to(device))
           
        
        # Forward pass for source
        source_outputs = model(input_ids=source_input_ids, attention_mask=source_attention_mask,  output_hidden_states=True)
        # source_loss = source_outputs.loss
        source_features = source_outputs.hidden_states[-1][:,0,:]  # Use CLS token's representations
        
        # Forward pass for target (without labels)
        with torch.no_grad():
            target_outputs = model(input_ids=target_input_ids, attention_mask=target_attention_mask, output_hidden_states=True,  )
            target_features = target_outputs.hidden_states[-1][:,0,:]  # Use CLS token's representations

        # Compute CORAL loss and combine with source task loss
        coral_loss = coral(source_features, target_features)
        # total_loss = source_loss + coral_loss

        coral_loss.backward()
        optimizer.step()
        
        total_coral_loss += coral_loss.item()
    
    return total_coral_loss / len(source_loader)

# Placeholder for a training epoch call
# Adjust device as per your setup
device = config.Config.DEVICE
model.to(device)
num_epochs = 40
for epoch in range(1, num_epochs+1):
    coral_loss = train_epoch(model, source_data_loader, target_data_loader, optimizer, coral, device)
    print(f"Epoch {epoch}, CORAL Loss: {coral_loss}")


0it [00:00, ?it/s]

235it [01:09,  3.37it/s]


Epoch 1, CORAL Loss: 5.075904123519279e-07


235it [01:10,  3.31it/s]


Epoch 2, CORAL Loss: 4.88355531195669e-07


235it [01:10,  3.32it/s]


Epoch 3, CORAL Loss: 4.748216336238352e-07


235it [01:12,  3.24it/s]


Epoch 4, CORAL Loss: 4.613791394941718e-07


235it [01:13,  3.21it/s]


Epoch 5, CORAL Loss: 4.4717277545751967e-07


235it [01:13,  3.21it/s]


Epoch 6, CORAL Loss: 4.3657416881615063e-07


235it [01:13,  3.21it/s]


Epoch 7, CORAL Loss: 4.286468380553937e-07


235it [01:13,  3.21it/s]


Epoch 8, CORAL Loss: 4.2307764620412667e-07


235it [01:13,  3.21it/s]


Epoch 9, CORAL Loss: 4.1807791734435125e-07


235it [01:13,  3.21it/s]


Epoch 10, CORAL Loss: 4.0823419324384825e-07


235it [01:13,  3.21it/s]


Epoch 11, CORAL Loss: 4.010908711459509e-07


235it [01:13,  3.22it/s]


Epoch 12, CORAL Loss: 3.9616103206925257e-07


235it [01:12,  3.22it/s]


Epoch 13, CORAL Loss: 3.9240797281570983e-07


235it [01:13,  3.22it/s]


Epoch 14, CORAL Loss: 3.8105096040762235e-07


235it [01:13,  3.21it/s]


Epoch 15, CORAL Loss: 3.75621822405332e-07


235it [01:13,  3.21it/s]


Epoch 16, CORAL Loss: 3.7235200432038173e-07


235it [01:13,  3.21it/s]


Epoch 17, CORAL Loss: 3.6285551481030135e-07


235it [01:13,  3.21it/s]


Epoch 18, CORAL Loss: 3.5927000582254864e-07


235it [01:13,  3.21it/s]


Epoch 19, CORAL Loss: 3.4989763142186795e-07


235it [01:13,  3.21it/s]


Epoch 20, CORAL Loss: 3.424234637621685e-07


235it [01:13,  3.21it/s]


Epoch 21, CORAL Loss: 3.337968529192442e-07


235it [01:13,  3.21it/s]


Epoch 22, CORAL Loss: 3.2706632608147874e-07


235it [01:13,  3.22it/s]


Epoch 23, CORAL Loss: 3.110322646947259e-07


235it [01:13,  3.22it/s]


Epoch 24, CORAL Loss: 2.9735076041664625e-07


235it [01:13,  3.21it/s]


Epoch 25, CORAL Loss: 2.821682020079008e-07


235it [01:13,  3.21it/s]


Epoch 26, CORAL Loss: 2.688701613012222e-07


235it [01:13,  3.20it/s]


Epoch 27, CORAL Loss: 2.5916367873002816e-07


235it [01:13,  3.21it/s]


Epoch 28, CORAL Loss: 2.5105611730814743e-07


235it [01:13,  3.21it/s]


Epoch 29, CORAL Loss: 2.4264918403604323e-07


235it [01:13,  3.21it/s]


Epoch 30, CORAL Loss: 2.3639950987716105e-07


235it [01:13,  3.21it/s]


Epoch 31, CORAL Loss: 2.283063135092784e-07


235it [01:13,  3.21it/s]


Epoch 32, CORAL Loss: 2.216354178893509e-07


235it [01:13,  3.22it/s]


Epoch 33, CORAL Loss: 2.1452044699178383e-07


235it [01:12,  3.22it/s]


Epoch 34, CORAL Loss: 2.0918636368847327e-07


235it [01:13,  3.21it/s]


Epoch 35, CORAL Loss: 2.0333353703661152e-07


235it [01:13,  3.21it/s]


Epoch 36, CORAL Loss: 1.9888478232539975e-07


235it [01:13,  3.21it/s]


Epoch 37, CORAL Loss: 1.9253253763360358e-07


235it [01:13,  3.21it/s]


Epoch 38, CORAL Loss: 1.8767432455994053e-07


235it [01:13,  3.21it/s]


Epoch 39, CORAL Loss: 1.8356272120169525e-07


235it [01:13,  3.21it/s]

Epoch 40, CORAL Loss: 1.7894674139037166e-07





In [52]:
model.save_adapter(f"{config.Config.ADAPTER_SAVE_PATH}/{adapter_name}", adapter_name)

In [53]:
model.add_classification_head(
    "task-test-after-coral",
    num_labels=3,
  )

In [54]:
model.adapter_summary(as_dict=True)

[{'name': 'domain-adapter-coral',
  'architecture': 'bottleneck',
  'active': True,
  '#param': 447264,
  'train': True,
  '%param': 0.6739671334336303},
 {'name': 'Full model', '#param': 66362880, '%param': 100.0, 'train': False}]

In [55]:
accuracy_before, f1_before = fn.evaluate_model(model, target_test_loader)
print(f"Accuracy before adaptation: {accuracy_before}")
print(f"F1 score before adaptation: {f1_before}")

Accuracy before adaptation: 0.3426113360323887
F1 score before adaptation: 0.22419396440643205
