In [1]:
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

In [2]:
# Load the code-embedding checkpoint from the Hugging Face Hub.
# trust_remote_code=True allows the checkpoint's own modelling code to execute in
# this kernel. NOTE(review): consider pinning a specific `revision` so upstream
# changes to that code cannot silently change (or compromise) this notebook.
model = SentenceTransformer(
    "Salesforce/SFR-Embedding-Code-400M_R",
    trust_remote_code=True,
)

In [7]:
# Three Java unit-test source files to compare with the embedding model.
# Index 0 is the reference snippet; indices 1.. are the candidates it is scored
# against in the cosine-similarity cell (cos_sim(embeddings[0], embeddings[1:])).
# The triple-quoted strings are the data being embedded — do not reformat them.
sentences = [
    # [0] DirectRequestTest (package org.littleshoot.proxy) — reference snippet
    """package org.littleshoot.proxy;

import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.HttpObject;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.littleshoot.proxy.impl.DefaultHttpProxyServer;
import org.littleshoot.proxy.test.HttpClientUtil;

import javax.net.ssl.SSLException;

import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;

/**
 * This class tests direct requests to the proxy server, which causes endless
 * loops (#205).
 */
public class DirectRequestTest {

    private HttpProxyServer proxyServer;

    @Before
    public void setUp() throws Exception {
        proxyServer = null;
    }

    @After
    public void tearDown() throws Exception {
        if (proxyServer != null) {
            proxyServer.abort();
        }
    }

    @Test(timeout = 5000)
    public void testAnswerBadRequestInsteadOfEndlessLoop() throws Exception {

        startProxyServer();

        int proxyPort = proxyServer.getListenAddress().getPort();
        org.apache.http.HttpResponse response = HttpClientUtil.performHttpGet("http://127.0.0.1:" + proxyPort + "/directToProxy", proxyServer);
        int statusCode = response.getStatusLine().getStatusCode();

        assertEquals("Expected to receive an HTTP 400 from the server", 400, statusCode);
    }

    @Test(timeout = 5000)
    public void testAnswerFromFilterShouldBeServed() throws Exception {

        startProxyServerWithFilterAnsweringStatusCode(403);

        int proxyPort = proxyServer.getListenAddress().getPort();
        org.apache.http.HttpResponse response = HttpClientUtil.performHttpGet("http://localhost:" + proxyPort + "/directToProxy", proxyServer);
        int statusCode = response.getStatusLine().getStatusCode();

        assertEquals("Expected to receive an HTTP 403 from the server", 403, statusCode);
    }

    private void startProxyServerWithFilterAnsweringStatusCode(int statusCode) {
        final HttpResponseStatus status = HttpResponseStatus.valueOf(statusCode);
        HttpFiltersSource filtersSource = new HttpFiltersSourceAdapter() {
            @Override
            public HttpFilters filterRequest(HttpRequest originalRequest) {
                return new HttpFiltersAdapter(originalRequest) {
                    @Override
                    public HttpResponse clientToProxyRequest(HttpObject httpObject) {
                        return new DefaultHttpResponse(HttpVersion.HTTP_1_1, status);
                    }
                };
            }
        };

        proxyServer = DefaultHttpProxyServer.bootstrap()
                .withPort(0)
                .withFiltersSource(filtersSource)
                .start();
    }

    @Test(timeout = 5000)
    public void testHttpsShouldCancelConnection() {
        startProxyServer();

        int proxyPort = proxyServer.getListenAddress().getPort();


        try {
            HttpClientUtil.performHttpGet("https://localhost:" + proxyPort + "/directToProxy", proxyServer);
        } catch (RuntimeException e) {
            Throwable cause = e.getCause();
            assertThat("Expected an SSL exception when attempting to perform an HTTPS GET directly to the proxy", cause, instanceOf(SSLException.class));
        }
    }

    private void startProxyServer() {
        proxyServer = DefaultHttpProxyServer.bootstrap()
                .withPort(0)
                .start();
    }

}""",
    # [1] IdleTest (package org.littleshoot.proxy) — same package as [0]
"""package org.littleshoot.proxy;

import java.net.InetSocketAddress;
import java.net.Proxy;
import java.net.URL;

import org.eclipse.jetty.server.Server;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.littleshoot.proxy.impl.DefaultHttpProxyServer;

import static org.hamcrest.Matchers.lessThan;
import static org.junit.Assert.assertThat;
import static org.junit.Assume.assumeFalse;
import static org.junit.Assume.assumeTrue;

/**
 * Note - this test only works on UNIX systems because it checks file descriptor
 * counts.
 */
public class IdleTest {
    private static final int NUMBER_OF_CONNECTIONS_TO_OPEN = 2000;

    private Server webServer;
    private int webServerPort = -1;
    private HttpProxyServer proxyServer;

    @Before
    public void setup() throws Exception {
        assumeTrue("Skipping due to non-Unix OS", TestUtils.isUnixManagementCapable());

        assumeFalse("Skipping for travis-ci build", "true".equals(System.getenv("TRAVIS")));

        webServer = new Server(0);
        webServer.start();
        webServerPort = TestUtils.findLocalHttpPort(webServer);

        proxyServer = DefaultHttpProxyServer.bootstrap()
                .withPort(0)
                .start();
        proxyServer.setIdleConnectionTimeout(10);

    }

    @After
    public void tearDown() throws Exception {
        try {
            if (webServer != null) {
                webServer.stop();
            }
        } finally {
            if (proxyServer != null) {
                proxyServer.abort();
            }
        }
    }

    @Test
    public void testFileDescriptorCount() throws Exception {
        System.out
                .println("------------------ Memory Usage At Beginning ------------------");
        long initialFileDescriptors = TestUtils.getOpenFileDescriptorsAndPrintMemoryUsage();
        Proxy proxy = new Proxy(Proxy.Type.HTTP, new InetSocketAddress(
                "127.0.0.1", proxyServer.getListenAddress().getPort()));
        for (int i = 0; i < NUMBER_OF_CONNECTIONS_TO_OPEN; i++) {
            new URL("http://localhost:" + webServerPort)
                    .openConnection(proxy).connect();
        }

        System.gc();
        System.out
                .println("\n\n------------------ Memory Usage Before Idle Timeout ------------------");

        long fileDescriptorsWhileConnectionsOpen = TestUtils.getOpenFileDescriptorsAndPrintMemoryUsage();
        Thread.sleep(10000);

        System.gc();
        System.out
                .println("\n\n------------------ Memory Usage After Idle Timeout ------------------");
        long fileDescriptorsAfterConnectionsClosed = TestUtils.getOpenFileDescriptorsAndPrintMemoryUsage();

        double fdDeltaToOpen = fileDescriptorsWhileConnectionsOpen
                - initialFileDescriptors;
        double fdDeltaToClosed = fileDescriptorsAfterConnectionsClosed
                - initialFileDescriptors;

        double fdDeltaRatio = fdDeltaToClosed / fdDeltaToOpen;
        assertThat(
                "Number of file descriptors after close should be much closer to initial value than number of file descriptors while open (+ 1%).\n"
                        + "Initial file descriptors: " + initialFileDescriptors + "; file descriptors while connections open: " + fileDescriptorsWhileConnectionsOpen + "; "
                        + "file descriptors after connections closed: " + fileDescriptorsAfterConnectionsClosed + "\n"
                        + "Ratio of file descriptors after connections are closed to descriptors before connections were closed: " + fdDeltaRatio,
                fdDeltaRatio, lessThan(0.01));
    }
}""",
    # [2] BasicPoolTest (package com.zaxxer.hikari.db) — a different project
"""/*
 * Copyright (C) 2016 Brett Wooldridge
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.zaxxer.hikari.db;

import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import com.zaxxer.hikari.pool.HikariPool;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;

import static com.zaxxer.hikari.pool.TestElf.getPool;
import static com.zaxxer.hikari.pool.TestElf.newHikariConfig;
import static com.zaxxer.hikari.pool.TestElf.getUnsealedConfig;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

/**
 * @author brettw
 *
 */
public class BasicPoolTest
{
   @Before
   public void setup() throws SQLException
   {
       HikariConfig config = newHikariConfig();
       config.setMinimumIdle(1);
       config.setMaximumPoolSize(2);
       config.setConnectionTestQuery("SELECT 1");
       config.setDataSourceClassName("org.h2.jdbcx.JdbcDataSource");
       config.addDataSourceProperty("url", "jdbc:h2:mem:test;DB_CLOSE_DELAY=-1");

       try (HikariDataSource ds = new HikariDataSource(config);
            Connection conn = ds.getConnection();
            Statement stmt = conn.createStatement()) {
          stmt.execute("DROP TABLE IF EXISTS basic_pool_test");
          stmt.execute("CREATE TABLE basic_pool_test ("
                            + "id INTEGER NOT NULL PRIMARY KEY, "
                            + "timestamp TIMESTAMP, "
                            + "string VARCHAR(128), "
                            + "string_from_number NUMERIC "
                            + ")");
       }
   }

   @Test
   public void testIdleTimeout() throws InterruptedException, SQLException
   {
      HikariConfig config = newHikariConfig();
      config.setMinimumIdle(5);
      config.setMaximumPoolSize(10);
      config.setConnectionTestQuery("SELECT 1");
      config.setDataSourceClassName("org.h2.jdbcx.JdbcDataSource");
      config.addDataSourceProperty("url", "jdbc:h2:mem:test;DB_CLOSE_DELAY=-1");

      System.setProperty("com.zaxxer.hikari.housekeeping.periodMs", "1000");

      try (HikariDataSource ds = new HikariDataSource(config)) {
         getUnsealedConfig(ds).setIdleTimeout(3000);

         System.clearProperty("com.zaxxer.hikari.housekeeping.periodMs");

         SECONDS.sleep(1);

         HikariPool pool = getPool(ds);

         assertEquals("Total connections not as expected", 5, pool.getTotalConnections());
         assertEquals("Idle connections not as expected", 5, pool.getIdleConnections());

         try (Connection connection = ds.getConnection()) {
            Assert.assertNotNull(connection);

            MILLISECONDS.sleep(1500);

            assertEquals("Second total connections not as expected", 6, pool.getTotalConnections());
            assertEquals("Second idle connections not as expected", 5, pool.getIdleConnections());
         }

         assertEquals("Idle connections not as expected", 6, pool.getIdleConnections());

         MILLISECONDS.sleep(3000);

         assertEquals("Third total connections not as expected", 5, pool.getTotalConnections());
         assertEquals("Third idle connections not as expected", 5, pool.getIdleConnections());
      }
   }

   @Test
   public void testIdleTimeout2() throws InterruptedException, SQLException
   {
      HikariConfig config = newHikariConfig();
      config.setMaximumPoolSize(50);
      config.setConnectionTestQuery("SELECT 1");
      config.setDataSourceClassName("org.h2.jdbcx.JdbcDataSource");
      config.addDataSourceProperty("url", "jdbc:h2:mem:test;DB_CLOSE_DELAY=-1");

      System.setProperty("com.zaxxer.hikari.housekeeping.periodMs", "1000");

      try (HikariDataSource ds = new HikariDataSource(config)) {
         System.clearProperty("com.zaxxer.hikari.housekeeping.periodMs");

         SECONDS.sleep(3);

         HikariPool pool = getPool(ds);

         getUnsealedConfig(ds).setIdleTimeout(3000);

         assertEquals("Total connections not as expected", 50, pool.getTotalConnections());
         assertEquals("Idle connections not as expected", 50, pool.getIdleConnections());

         try (Connection connection = ds.getConnection()) {
            assertNotNull(connection);

            MILLISECONDS.sleep(1500);

            assertEquals("Second total connections not as expected", 50, pool.getTotalConnections());
            assertEquals("Second idle connections not as expected", 49, pool.getIdleConnections());
         }

         assertEquals("Idle connections not as expected", 50, pool.getIdleConnections());

         SECONDS.sleep(3);

         assertEquals("Third total connections not as expected", 50, pool.getTotalConnections());
         assertEquals("Third idle connections not as expected", 50, pool.getIdleConnections());
      }
   }
}"""
]

In [8]:
# Encode all snippets in a single call. Row i of the result is the embedding of
# sentences[i]; the similarity cell relies on that ordering (row 0 = reference).
embeddings = model.encode(sentences)
# Sanity check: exactly one embedding row per snippet, so later row indexing is valid.
assert embeddings.shape[0] == len(sentences), embeddings.shape

In [9]:
# Cosine similarity of the reference snippet (row 0) against every other snippet.
# The result is a (1, len(sentences) - 1) matrix: column j scores sentences[j + 1].
similarities = cos_sim(embeddings[0], embeddings[1:])
assert tuple(similarities.shape) == (1, len(sentences) - 1), similarities.shape
# Bare last expression: let Jupyter render the result instead of print().
similarities

tensor([[0.9012, 0.8625]])


In [6]:
# Inspect the raw embedding vector of the reference snippet (first row).
embeddings[0, :]

array([ 0.25360975, -0.2994772 ,  0.29788098, ..., -0.5058748 ,
       -0.38796282,  0.14717735], shape=(1024,), dtype=float32)