In [None]:
import os

import torch
import torchvision
from torch import nn
from transformers import AutoTokenizer, CLIPModel, CLIPVisionModelWithProjection
import yaml
import matplotlib.pyplot as plt
from PIL import Image
import requests

from train_c_lora import WurstCore
from train_b import WurstCore as WurstCoreB
from warp_core.utils import load_or_fail

# Use the first CUDA GPU when one is visible to torch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# SETUP WARPCORE
def load_notebook_config(path, batch_size=4):
    """Load a YAML config file and apply the overrides used in this notebook.

    Args:
        path: Path of the YAML config file to read.
        batch_size: Value stored under the ``batch_size`` key.

    Returns:
        The parsed config with ``use_fsdp`` set to ``False`` and
        ``batch_size`` set to ``batch_size``.
    """
    with open(path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    # The overrides only touch the parsed dict, so the file can already be closed.
    config['use_fsdp'] = False
    config['batch_size'] = batch_size
    return config


# Core built from the `finetune_c_3b_lora` config.
config_file = 'configs/finetune_c_3b_lora.yml'
loaded_config = load_notebook_config(config_file)

warpcore = WurstCore(
    config_dict=loaded_config,
    device=device
)

# STAGE B
config_path_b = 'configs/finetune_b_3b.yml'
loaded_config_b = load_notebook_config(config_path_b)
# `config_file_b` previously held first the path and then the parsed dict;
# it is kept as an alias of the dict so its final value is unchanged.
config_file_b = loaded_config_b

warpcore_b = WurstCoreB(
    config_dict=loaded_config_b,
    device=device
)

In [115]:
# SETUP MODELS & DATA
# Build the extras and the data object for `warpcore`. Model construction for
# this core is left commented out below, so no `models` variable is created here.
extras = warpcore.setup_extras_pre()
data = warpcore.setup_data(extras)
# models = warpcore.setup_models(extras)
# models.generator.bfloat16().eval().requires_grad_(False)
print("CONTROLNET READY")

# Build the extras and the models for `warpcore_b`. The dtype/eval/freeze call
# on its generator is left commented out.
extras_b = warpcore_b.setup_extras_pre()
models_b = warpcore_b.setup_models(extras_b)
# models_b.generator.bfloat16().eval().requires_grad_(False)
print("STAGE B READY")
# Trailing no-op statement.
pass

In [116]:
# Collect every sub-module of `model` that is a LoRA instance, either directly
# or through an `_fsdp_wrapped_module` attribute.
# NOTE(review): `model` and `LoRA` are not defined or imported in any earlier
# visible cell — confirm where they are expected to come from.
lora_modules = nn.ModuleList([
    candidate
    for candidate in model.modules()
    if isinstance(candidate, LoRA)
    or isinstance(getattr(candidate, '_fsdp_wrapped_module', None), LoRA)
])
print(len(lora_modules), lora_modules)

0 ModuleList()


# CLIP text encoder: trainable token embeddings (ReToken)

In [None]:
class ReToken(nn.Module):
    """Parametrization that adds a trainable offset to selected rows of a matrix.

    One offset row (initialised to zero) is stored per entry of ``indices``.
    ``forward`` returns a copy of its input in which row ``indices[i]`` has
    ``embeddings[i]`` added to it.

    Args:
        indices: Row indices that receive a trainable offset. Required.
        dim: Width of one row. Defaults to 1280, the value that used to be
            hard-coded.
    """

    def __init__(self, indices=None, dim=1280):
        super().__init__()
        assert indices is not None, "ReToken needs the row indices to update"
        # Always build an int64 index vector (also for an empty list) and copy
        # it so the buffer never aliases a tensor owned by the caller.
        index_tensor = torch.as_tensor(indices, dtype=torch.long).reshape(-1).clone()
        self.embeddings = nn.Parameter(torch.zeros(index_tensor.numel(), dim))
        self.register_buffer('indices', index_tensor)

    def forward(self, embeddings):
        """Return ``embeddings`` with the trainable offsets added to the selected rows.

        The input tensor is not modified. When this module is used as a
        parametrization the input is the stored base weight, and an in-place
        update would add the offsets to it again on every access.
        """
        # index_add requires matching dtypes; cast the offsets to the input's dtype.
        offsets = self.embeddings.to(dtype=embeddings.dtype)
        # Out-of-place scatter-add along dim 0; repeated indices accumulate,
        # exactly like the former per-row `+=` loop.
        return embeddings.index_add(0, self.indices, offsets)
    
def apply_retoken(module, indices=None):
    """Register a ``ReToken`` parametrization on ``module.weight`` when eligible.

    The parametrization is registered only if ``module`` has a ``weight``
    attribute, that attribute is not already parametrized, and it is an
    ``nn.Parameter``. In every other case the function does nothing.

    Args:
        module: Module whose ``weight`` should be parametrized.
        indices: Row indices forwarded to ``ReToken``.
    """
    parametrize = torch.nn.utils.parametrize

    if not hasattr(module, "weight"):
        return
    if parametrize.is_parametrized(module, "weight"):
        return
    if not isinstance(module.weight, nn.Parameter):
        return

    parametrize.register_parametrization(module, 'weight', ReToken(indices=indices))
        

In [197]:
# The tokenizer and the model are loaded from the same checkpoint name.
clip_checkpoint = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'
clip_tokenizer = AutoTokenizer.from_pretrained(clip_checkpoint)
clip_model = CLIPModel.from_pretrained(clip_checkpoint)

# Take the text model, move it to the CPU, switch to eval mode and freeze it.
clip_text_model = clip_model.text_model
clip_text_model = clip_text_model.to('cpu')
clip_text_model = clip_text_model.eval()
clip_text_model = clip_text_model.requires_grad_(False)

# Print the vocabulary size before and after adding the new "[snail]" token.
vocab_size_before = len(clip_tokenizer.vocab)
print(vocab_size_before)
clip_tokenizer.add_tokens(["[snail]"])
vocab_size_after = len(clip_tokenizer.vocab)
print(vocab_size_after)

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

49408
49409


In [198]:
# Tokenize three prompts and print their ids: the plain word, its plural, and
# the newly added "[snail]" token. `tokens` ends up holding the last result.
for prompt in ('the snail is small', 'the snails are small', 'the [snail] is small'):
    tokens = clip_tokenizer(
        [prompt],
        truncation=True,
        padding="max_length",
        max_length=clip_tokenizer.model_max_length,
        return_tensors="pt",
    )
    print(tokens['input_ids'])


tensor([[49406,   518, 23132,   533,  2442, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407]])
tensor([[49406,   518, 34928,   631,  2442, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
         49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,

In [199]:
# Token ids whose embedding rows get a trainable offset: every vocabulary entry
# that starts with "snail", plus the added "[snail]" token.
# `str.startswith` matches exactly what the anchored regex "^snail" matched and
# does not need the `re` module, which this notebook never imports.
placeholder_token = "[snail]"
update_indices = [
    token_id
    for token, token_id in clip_tokenizer.get_vocab().items()
    if token.startswith("snail")
]
# Look the placeholder id up explicitly instead of assuming it is the last
# entry of the vocabulary.
update_indices.append(clip_tokenizer.convert_tokens_to_ids(placeholder_token))
print(update_indices)

# Grow the embedding matrix with zero rows until it covers the tokenizer's full
# vocabulary. Only the missing rows are appended, so re-running this cell does
# not add another row each time.
token_embedding = clip_text_model.embeddings.token_embedding
missing_rows = len(clip_tokenizer) - token_embedding.weight.shape[0]
if missing_rows > 0:
    current_rows = token_embedding.weight.data
    token_embedding.weight.data = torch.cat([
        current_rows,
        current_rows.new_zeros(missing_rows, current_rows.shape[1]),
    ], dim=0)
    # Keep the layer's bookkeeping attribute in sync with the new row count.
    token_embedding.num_embeddings = token_embedding.weight.shape[0]

apply_retoken(token_embedding, update_indices)

[34928, 23132, 49408]
>> 0 tensor(34928)
>> 1 tensor(23132)
>> 2 tensor(49408)


In [200]:
# Smoke test: run 100 random token ids through the token embedding layer.
random_token_ids = torch.randint(0, len(clip_tokenizer.vocab), size=(100,))
clip_text_model.embeddings.token_embedding(random_token_ids)

>> 0 tensor(34928)
>> 1 tensor(23132)
>> 2 tensor(49408)


tensor([[-0.0071,  0.0082, -0.0185,  ..., -0.0069, -0.0014,  0.0088],
        [ 0.0186, -0.0301,  0.0130,  ...,  0.0110, -0.0009, -0.0193],
        [-0.0017,  0.0051,  0.0044,  ...,  0.0016,  0.0092, -0.0060],
        ...,
        [ 0.0066, -0.0018,  0.0110,  ...,  0.0049, -0.0225,  0.0067],
        [-0.0001,  0.0223, -0.0089,  ...,  0.0011, -0.0195, -0.0008],
        [ 0.0013,  0.0120, -0.0106,  ..., -0.0001, -0.0024,  0.0111]],
       grad_fn=<EmbeddingBackward0>)

In [201]:
# Display the first parametrization registered on the token embedding's `weight`.
clip_text_model.embeddings.token_embedding.parametrizations.weight[0]

ReToken()

In [203]:
# Gather every ReToken found under the token embedding (directly, or through an
# `_fsdp_wrapped_module` attribute) into `trainable_modules['retokens']`.
retoken_modules = [
    candidate
    for candidate in clip_text_model.embeddings.token_embedding.modules()
    if isinstance(candidate, ReToken)
    or isinstance(getattr(candidate, '_fsdp_wrapped_module', None), ReToken)
]
trainable_modules = nn.ModuleDict({'retokens': nn.ModuleList(retoken_modules)})
print(len(trainable_modules['retokens']), trainable_modules['retokens'])


1 ModuleList(
  (0): ReToken()
)


In [206]:
# Display the first parametrization on the token embedding's `weight` again,
# this time after the trainable modules have been collected.
clip_text_model.embeddings.token_embedding.parametrizations.weight[0]

ReToken()

In [242]:
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.distributed import init_process_group, destroy_process_group
import functools

# Toy network: three 64 -> 64 linear layers in sequence.
model = nn.Sequential(*(nn.Linear(64, 64) for _ in range(3)))

# Module list holding only the first layer (the same object as model[0]).
first_layer = model[0]
submodules = nn.ModuleList([first_layer])

In [244]:
# Single-process (rank 0 of 1) NCCL process group, using a rendezvous file in
# the current working directory.
dist_file_path = os.path.join(os.getcwd(), "dist_file_test")
init_process_group(
    backend="nccl",
    rank=0,
    world_size=1,
    init_method=f"file://{dist_file_path}",
)

# Tear the group down in `finally`: if wrapping or printing raises, the group
# would otherwise stay initialised and the next run of this cell would fail
# when it initialises the group again.
try:
    # fsdp_auto_wrap_policy = ModuleWrapPolicy([nn.Linear])
    fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000)
    # Rebinds `submodules` to the FSDP wrapper; `model` is printed next to it
    # for comparison.
    submodules = FSDP(submodules, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=0)
    print(submodules)
    print(model)
finally:
    destroy_process_group()

FullyShardedDataParallel(
  (_fsdp_wrapped_module): ModuleList(
    (0): FullyShardedDataParallel(
      (_fsdp_wrapped_module): Linear(in_features=64, out_features=64, bias=True)
    )
  )
)
Sequential(
  (0): Linear(in_features=64, out_features=64, bias=True)
  (1): Linear(in_features=64, out_features=64, bias=True)
  (2): Linear(in_features=64, out_features=64, bias=True)
)


In [243]:
# Manual cleanup of a process group left over from an interrupted run.
# Guarded so the cell does nothing, instead of raising, when no default
# process group is currently initialised.
if torch.distributed.is_initialized():
    destroy_process_group()

In [252]:
# nn.functional.one_hot(torch.tensor(5), num_classes=10)