Using GTR with contrastive learning but the weights are not updated. #2554

Open

mrpeerat opened this issue Mar 23, 2024 · 1 comment

@mrpeerat

Hi!
I'm using sentence-transformers/gtr-t5-base as the base encoder with SimCSE in Sentence Transformers (this example).
However, the dev score on the STS-B dev set never changes during training.
Here is an example:
[Screenshot: the sts-dev score stays unchanged across evaluation steps]

Here is the code that I use:

import csv
import gzip
import logging
import math

from torch.utils.data import DataLoader
from sentence_transformers import InputExample, SentenceTransformer, losses, models
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

model_alls = ['sentence-transformers/gtr-t5-base']

max_seq_length = 512
num_epochs = 4
training_data = 'wiki1m_for_simcse.txt'
sts_dataset_path = 'stsbenchmark.tsv.gz'
pooling_mode = 'mean'  # cls, max, mean
wiki_dataset = True

for model_name in model_alls:
    train_batch_size = 32
    learning_rate = 3e-5

    model_save_path = f"output/Simcse_{model_name.replace('-', '_').replace('/', '_')}"

    word_embedding_model = models.Transformer(model_name, max_seq_length=max_seq_length)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode=pooling_mode)
    model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

    # Read the STSbenchmark dataset and use it as the development set
    logging.info("Read STSbenchmark dev dataset")
    dev_samples = []
    with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as fIn:
        reader = csv.DictReader(fIn, delimiter='\t', quoting=csv.QUOTE_NONE)
        for row in reader:
            if row['split'] == 'dev':
                score = float(row['score']) / 5.0  # Normalize score to the range 0 ... 1
                dev_samples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=score))

    dev_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, batch_size=train_batch_size, name='sts-dev')

    train_data = []
    if wiki_dataset:
        print("Wiki dataset")
        with open(training_data, encoding='utf8') as fIn:
            train_sentences = [line.strip() for line in fIn]
        # SimCSE: each sentence is paired with itself; dropout acts as the augmentation
        train_data += [InputExample(texts=[s, s]) for s in train_sentences]

    train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size, drop_last=True)

    logging.info("Train sentences: {}".format(len(train_data)))

    # We train our model using the MultipleNegativesRankingLoss
    train_loss = losses.MultipleNegativesRankingLoss(model)

    evaluation_steps = 125  # Evaluate every 125 training steps
    warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of train data for warm-up
    logging.info("Warmup-steps: {}".format(warmup_steps))

    # Train the model
    model.fit(train_objectives=[(train_dataloader, train_loss)],
              epochs=num_epochs,
              evaluator=dev_evaluator,
              evaluation_steps=evaluation_steps,
              warmup_steps=warmup_steps,
              optimizer_params={'lr': learning_rate},
              show_progress_bar=True,
              output_path=model_save_path,
              use_amp=True  # Set to True if your GPU supports FP16
              )

Thank you in advance.

@tomaarsen
Collaborator

Hello!

This took quite a while to debug, but it seems that the combination of fp16 and softmax is a bit unstable: it produces NaN after a Dropout layer, which then turns the loss itself into NaN and prevents any further training. See also https://discuss.pytorch.org/t/getting-nans-from-dropout-layer/70693/5
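
To see this failure mode in isolation, here is a minimal sketch (plain PyTorch, not the library's internals): fp16 overflows above roughly 65504, an inf logit turns an otherwise numerically stable softmax into NaN, and the NaN survives Dropout and every operation after it.

import torch
import torch.nn.functional as F

# fp16 overflows above ~65504, so a large activation becomes inf ...
logits = torch.tensor([1.0, 70000.0]).to(torch.float16)
print(logits)                         # tensor([1., inf], dtype=torch.float16)

# ... and an inf logit makes the softmax output NaN
probs = torch.softmax(logits, dim=0)
print(probs)                          # tensor([nan, nan], dtype=torch.float16)

# NaN survives Dropout, so any downstream loss becomes NaN too
print(F.dropout(probs, p=0.1, training=True))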
I was able to resolve the issue by training without FP16, i.e. by setting use_amp=False. Hope this helps!
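
If you'd rather have training fail loudly than silently flat-line, one option is to wrap the loss module and raise as soon as the loss becomes NaN. This is a sketch, not part of the sentence-transformers API; the NaNCheckedLoss name is mine, and it relies on loss modules taking (sentence_features, labels) in forward.

import torch
from sentence_transformers import losses

class NaNCheckedLoss(torch.nn.Module):
    """Hypothetical wrapper: delegates to a sentence-transformers loss
    and aborts as soon as the loss value turns into NaN."""
    def __init__(self, base_loss):
        super().__init__()
        self.base_loss = base_loss

    def forward(self, sentence_features, labels):
        loss = self.base_loss(sentence_features, labels)
        if torch.isnan(loss):
            raise RuntimeError("NaN loss detected - try use_amp=False")
        return loss

# Drop-in replacement in the training script above
# (`model` is the SentenceTransformer built there):
train_loss = NaNCheckedLoss(losses.MultipleNegativesRankingLoss(model))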

  • Tom Aarsen
