In [1]:
#create prompt 
import clip
import numpy as np
import torch
import torch.nn.functional as F
device='cuda:0'
clipmodel, preprocess = clip.load("ViT-B/16", device=device)

text_prompts = []
label=['dog','cat']
for classname in label:
    texts = ['A photo of a '+classname+'.']
    texts = clip.tokenize(texts).to(device)  # tokenize
    class_embeddings = clipmodel.encode_text(texts)
    class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0).detach().cpu()
    text_prompts.append(class_embedding)
text_prompts = torch.stack(text_prompts, dim=1)
text_prompts = text_prompts.transpose(1,0)
# np.save('demo_text_feature.npy',text_prompts)

In [2]:
import timm
print(timm.__version__)

0.4.12


In [4]:
#train demo
import torch
import instruction_ViT
from timm.scheduler.cosine_lr import CosineLRScheduler
import warnings
warnings.filterwarnings('ignore')
import numpy as np
from timm.utils import accuracy, AverageMeter
import time
import datetime
from timm.data import Mixup
import timm


def train_one_epoch(model,epoch,loss_fn,optimizer,train_dataloader,mixup_fn):
    model.train()
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    start = time.time()
    end = time.time()

    for samples,targets in train_dataloader:
        samples=samples.cuda(non_blocking=True)
        targets=targets.cuda(non_blocking=True)
        
        if mixup_fn is not None:
            if len(targets) % 2 == 0:
                samples, targets = mixup_fn(samples, targets)
        output,output2=model(samples)
        loss1=loss_fn(output,targets)
        loss2=loss_fn(output2,targets)
        loss = loss1+loss2
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss_meter.update(loss.item(), targets.size(0))
        batch_time.update(time.time() - end)
        end = time.time()
    epoch_time = time.time() - start
    print(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))} loss {loss_meter.avg:.5f}")


def validate(model,loss_fn,val_dataloader):
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    with torch.no_grad():
        model.eval()
        for images,target in val_dataloader:
            images=images.cuda(non_blocking=True)
            target=target.cuda(non_blocking=True)
            _,output=model(images)
            loss=loss_fn(output,target)
            acc1,_ = accuracy(output, target, topk=(1,2))
            loss_meter.update(loss.item(), target.size(0))
            acc1_meter.update(acc1.item(), target.size(0))
    print(f' * Acc@1 {acc1_meter.avg:.3f}  loss{loss_meter.avg:.5f}')

#create dataloader
from torch.utils.data import Dataset,DataLoader
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from randaugment import RandAugment
from PIL import Image
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC
def _convert_image_to_rgb(image):
        return image.convert("RGB")
def preprocess_img_train(n_px):
    return Compose([
    Resize(n_px, interpolation=BICUBIC),
    CenterCrop(n_px),
    _convert_image_to_rgb,
    RandAugment(),
    ToTensor(),
    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
class MyDataset_train(Dataset): #change for your dataset
    def __init__(self,fold_name,lable_feature_list):
        self.data=[]
        self.label=[]
        # label_index=np.arange(len(lable_feature_list))
        # data_path='data_path'
        # for i,fold in enumerate(fold_name):
        #     data_path_ = os.path.join(data_path,fold)
        #     self.data+=[os.path.join(data_path_,j) for j in os.listdir(data_path_)]
        #     self.label+=[label_index[i] for _ in range(len(os.listdir(data_path_)))]
        self.label=[0,0,0,0,0,1,1,1,1,1]
        self.data=[torch.rand((3,224,224)) for _ in range(10)]
        self.transform=preprocess_img_train(224) 
    def __len__(self):
        return len(self.label)
    def __getitem__(self, idx):
        # path=self.data[idx]
        # image=Image.open(path)
        # image = self.transform(image)
        image = self.data[idx]
        label = torch.tensor(self.label[idx])
        return image,label
    
def main(text_features,fold_name):
    model = timm.create_model('instruction_vit_base_patch16_224',pretrained=True,num_classes=text_features.shape[0]).to(device)
    optimizer=torch.optim.Adam(model.parameters(),lr=1e-4)
    lr_schedule=CosineLRScheduler(optimizer=optimizer,t_initial=10,lr_min=1e-5,warmup_t=5)
    loss_fn= torch.nn.CrossEntropyLoss()
    loss_fn=loss_fn.to(device)
    mixup_fn=None
    text_features_temp = text_features.to(device)
    model.reset_prompt(text_features_temp)
    for epoch in range(20):
        mixup_fn = Mixup(
                mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
                prob=0.1, switch_prob=0.5, mode='batch',
                label_smoothing=0.1, num_classes=len(fold_name))
        train_dataset=MyDataset_train(fold_name,text_features_temp)
        dataloader=DataLoader(train_dataset,batch_size=4,drop_last=True)
        train_one_epoch(model,epoch,loss_fn,optimizer,dataloader,mixup_fn)
        validate(model,loss_fn,dataloader)
        lr_schedule.step(epoch)
        torch.cuda.empty_cache()
        # if (epoch+1)%10==0:
        #     torch.save(model, "saved_parameters.pt")

if __name__ == '__main__':
    text_features = text_prompts
    fold_name=label
    seed = 1234
    torch.manual_seed(seed)
    np.random.seed(seed)
    main(text_features,fold_name)

EPOCH 0 training takes 0:00:06 loss 2.48819
 * Acc@1 62.500  loss1.72779
EPOCH 1 training takes 0:00:00 loss 2.39096
 * Acc@1 62.500  loss1.62505
EPOCH 2 training takes 0:00:00 loss 1.61891
 * Acc@1 37.500  loss2.01969
EPOCH 3 training takes 0:00:00 loss 3.18895
 * Acc@1 37.500  loss1.43756
EPOCH 4 training takes 0:00:00 loss 2.96795
 * Acc@1 62.500  loss0.65559
EPOCH 5 training takes 0:00:00 loss 1.35947
 * Acc@1 37.500  loss1.19269
EPOCH 6 training takes 0:00:00 loss 1.90922
 * Acc@1 62.500  loss0.64441
EPOCH 7 training takes 0:00:00 loss 1.61250
 * Acc@1 62.500  loss0.74937
EPOCH 8 training takes 0:00:00 loss 1.45637
 * Acc@1 62.500  loss0.67324
EPOCH 9 training takes 0:00:00 loss 1.34031
 * Acc@1 62.500  loss0.65628
EPOCH 10 training takes 0:00:00 loss 1.35720
 * Acc@1 37.500  loss0.68946
EPOCH 11 training takes 0:00:00 loss 1.40346
 * Acc@1 37.500  loss0.77397
EPOCH 12 training takes 0:00:00 loss 1.61456
 * Acc@1 62.500  loss0.65470
EPOCH 13 training takes 0:00:00 loss 1.51715
 * 